diff --git a/SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py b/SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py new file mode 100644 index 0000000000000000000000000000000000000000..a7205a8e78f9df1b0f7c2ae1e50a9eaf473f8a6d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 
'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2h/c2hvdjlmxyob2txn4nddktnqpzxakuy4vukk46jxvlks5plszr5r.py b/SpecForge-ext/cache/compiled_kernels/2h/c2hvdjlmxyob2txn4nddktnqpzxakuy4vukk46jxvlks5plszr5r.py new file mode 100644 index 0000000000000000000000000000000000000000..f39aa709bdfa1a51017a0b3c3d01193aa24c2cf9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2h/c2hvdjlmxyob2txn4nddktnqpzxakuy4vukk46jxvlks5plszr5r.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), 
r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py b/SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py new file mode 100644 index 0000000000000000000000000000000000000000..0c96182c6d4f497d2174ccf0a6377049f3b8cd0c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2s/c2sasa5yimiwlxmywmcvgtuh2fvol2mvhppzairkbqvuwicnbd5y.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + 
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', 
other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py b/SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py new file mode 100644 index 0000000000000000000000000000000000000000..b45f80c99dc0a3d71caa7ee22b8b87fe14425438 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 
'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 
'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always 
stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. 
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = 
(SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/2x/3dd4effcc6c7612a42d28cac3a6342345062808f2904d114a779d751ce7956b2.best_config b/SpecForge-ext/cache/compiled_kernels/2x/3dd4effcc6c7612a42d28cac3a6342345062808f2904d114a779d751ce7956b2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2x/3dd4effcc6c7612a42d28cac3a6342345062808f2904d114a779d751ce7956b2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py b/SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py new file mode 100644 index 0000000000000000000000000000000000000000..d25e6baeb247006d7d15e57e79178d390aec9530 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': 
False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), 
tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py b/SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py new file mode 100644 index 0000000000000000000000000000000000000000..9490bcec18b8c7d610e0e736f734293b4bfaf64c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 
'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py b/SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b2f2b2bac18861e6cc964cb303121164a518c5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + 
for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py b/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py new file mode 100644 index 0000000000000000000000000000000000000000..5125090333a6044eb4b46e80d8b881be6d8a5015 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py @@ -0,0 +1,56 @@ + +import triton 
+import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = 
tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/3k/c0fc7bc81a7e9d406f980957c0881903e8484dd7f57d970f2ddd21ca3ab2994d.best_config b/SpecForge-ext/cache/compiled_kernels/3k/c0fc7bc81a7e9d406f980957c0881903e8484dd7f57d970f2ddd21ca3ab2994d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/3k/c0fc7bc81a7e9d406f980957c0881903e8484dd7f57d970f2ddd21ca3ab2994d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py b/SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py new file mode 100644 index 0000000000000000000000000000000000000000..1d86f6833dd9f0709f1a7e201e44e2b85c83bf7e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3k/c3kdupo6eufhy2marzoeoddgc3okqj6m3aii3f42onl4ag77vf6u.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, 
eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py b/SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py new file mode 100644 index 0000000000000000000000000000000000000000..ad984631be6fc334e9eac24b1526a531839ec943 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3m/c3mfnz3jqpdzlott45yvd2kki53nhik366siiuob2jitdkwx6tyg.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 
'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) diff --git a/SpecForge-ext/cache/compiled_kernels/3p/c3pmafpvrty43do4nz3cf2mvhkihfulfxbiolmcu2votxja4s56e.py b/SpecForge-ext/cache/compiled_kernels/3p/c3pmafpvrty43do4nz3cf2mvhkihfulfxbiolmcu2votxja4s56e.py new file mode 100644 index 
0000000000000000000000000000000000000000..ed47aa7ade117817803ec035175733bb32f2d488 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3p/c3pmafpvrty43do4nz3cf2mvhkihfulfxbiolmcu2votxja4s56e.py @@ -0,0 +1,352 @@ +# AOT ID: ['14_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ob/cob65ptxwcswkyjowvaxmwnu4cpoiijoxwce6eyz2ndtpqxwqxm5.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# 
Graph fragment: +# %arg1_1 : Tensor "bf16[2, s3, 32000][32000*s3, 32000, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, 
out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qj/cqj5277ktaoo5rg4kvnn7pm72cbfiwp7hxewmxzj4aevxoorlebn.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[2, s3, 32000][s71, 32000, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from 
torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + 
_tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sx/csxbkyhsnglm4cv6i6ibhzv34wlbjpntyk3gj27zbyc4s4efmtxh.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_3 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[2, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, s3][s3, 1]cuda:4" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[2, s3, 1][s3, 1, 1]cuda:4" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_3 : Tensor "i64[2, s3][s3, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_3,), 
kwargs = {}) +# return %sum_1 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel 
= r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xe/cxeou6auzbu4dnrn2twxe573bmqovq7xnk4b6hydfbw53px4etc7.py +# Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %arg6_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %sum_1 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:4" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_2,%div +triton_red_fused_clamp_min_div_sum_3 = async_compile.triton('triton_red_fused_clamp_min_div_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + 
xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s14 = arg5_1 + assert_size_stride(arg1_1, (2, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg3_1, (2, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg4_1, (2, s3, 1), (s3, 1, 1)) + assert_size_stride(arg6_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 2*s3 + stream4 = 
get_raw_stream(4) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream4) + del arg1_1 + buf1 = empty_strided_cuda((2, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 2*s3 + stream4 = get_raw_stream(4) + triton_red_fused_argmax_1.run(arg3_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream4) + del arg3_1 + buf2 = empty_strided_cuda((), (), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 2*s3 + stream4 = get_raw_stream(4) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg4_1, buf2, 1, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream4) + del arg4_1 + del buf0 + del buf1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [sum_2, clamp_min, truediv], Original ATen: [aten.sum, aten.clamp_min, aten.div] + triton_red_fused_clamp_min_div_sum_3_r0_numel = 2*s14 + stream4 = get_raw_stream(4) + triton_red_fused_clamp_min_div_sum_3.run(arg6_1, buf2, buf4, 1, triton_red_fused_clamp_min_div_sum_3_r0_numel, stream=stream4) + del arg6_1 + del buf2 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1543 + arg1_1 = rand_strided((2, 1543, 32000), (49376000, 32000, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = 49600000 + arg3_1 = rand_strided((2, 1543, 32000), (49600000, 32000, 1), device='cuda:4', dtype=torch.float32) + arg4_1 = rand_strided((2, 1543, 1), (1543, 1, 1), device='cuda:4', dtype=torch.int64) + arg5_1 = 1543 + arg6_1 = rand_strided((2, 1543, 1), (1543, 1, 1), 
device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/3u/c3ukv75kqyf3oeeogojmsgmsebbc2fg3rqs4dsmnshhsgj4hjkzx.py b/SpecForge-ext/cache/compiled_kernels/3u/c3ukv75kqyf3oeeogojmsgmsebbc2fg3rqs4dsmnshhsgj4hjkzx.py new file mode 100644 index 0000000000000000000000000000000000000000..584b3d919b923a1e656651e026770491d8445e0c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3u/c3ukv75kqyf3oeeogojmsgmsebbc2fg3rqs4dsmnshhsgj4hjkzx.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu 
+empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/js/cjse6ak6jsp3o35wdszmvjyn4cqeqewbex3a5ks2m6fqecygrmmg.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_2 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s14, 151936][151936*s14, 151936, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:0" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[2, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[2, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_2 : Tensor "i64[2, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# 
return %argmax,%mul_2 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, 
in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + 
self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (2, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (2, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, s14, 1), (s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 2*s14 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream0) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1130 + arg1_1 = rand_strided((2, 1130, 151936), (171687680, 151936, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool) + arg3_1 = rand_strided((2, 1130, 1), (1130, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git 
a/SpecForge-ext/cache/compiled_kernels/3x/88057732cb1d7a775c254455fe42105016cd2d1ced3af1bd1fb079691b5972a1.best_config b/SpecForge-ext/cache/compiled_kernels/3x/88057732cb1d7a775c254455fe42105016cd2d1ced3af1bd1fb079691b5972a1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..266f460278e85c345185451f023083ef4f3937ee --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3x/88057732cb1d7a775c254455fe42105016cd2d1ced3af1bd1fb079691b5972a1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/3x/c3xxifdzdkxpgs4yujq2dd2lwhcylvzgstk2hao7g2yjirkxlafr.py b/SpecForge-ext/cache/compiled_kernels/3x/c3xxifdzdkxpgs4yujq2dd2lwhcylvzgstk2hao7g2yjirkxlafr.py new file mode 100644 index 0000000000000000000000000000000000000000..528d400f3ea553cc337b25179cfa55fa99686b8a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/3x/c3xxifdzdkxpgs4yujq2dd2lwhcylvzgstk2hao7g2yjirkxlafr.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) 
+ tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) 
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py b/SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py new file mode 100644 index 0000000000000000000000000000000000000000..b30fb55292538655098a6231dec65a10d456c1c0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/42/c42h7visn4guss7swxj4up2er4ije4hyno7yrughuvurnenh2pvd.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 
'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py b/SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecabd0be565ee97512eb001cbcf0b4eaa07ee41 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = 
arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K 
+= k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
# Process one (M2 x N2) tile of the dQ backward pass: recompute the attention
# probabilities for the tile, form dS, and accumulate dS @ K into `dq`.
#
# Inputs used by this block (the leading arg_* / ks* parameters are only
# threaded through and are not read here, apart from `in_ptr16`):
#   dq                 running dQ accumulator; updated and returned.
#   q, do, Di, lse     values already loaded by the caller for this row tile.
#   kT_ptrs, vT_ptrs   pointers to the current transposed K / V tiles.
#   offs_m2, offs_n2   row / column indices of the tile.
#   offs_k, offs_v     head-dimension indices for K and V.
#   Q_LEN, KV_LEN      sequence bounds used for masking.
#   in_ptr16           buffer indexed by `off_z`; the loaded value is compared
#                      against row/column indices in the mask below
#                      (presumably a per-batch valid length -- TODO confirm).
#   MATMUL_PRECISION   dtype `ds` is cast to before the final dot.
#   RCP_LN2            factor applied to the scores before exp2.
#   IS_FULL_BLOCKS     when true, the mask evaluation is skipped entirely.
# Returns: the updated `dq`.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time configuration baked into this specialization of the kernel.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: the (head-dim, sequence) argument order is reversed because K is transposed.
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # NOTE: pre_mod_scores is not referenced again within this block.
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but because we iterate
    # across the N dim here, M can still read out of bounds for the program
    # IDs spanning the Q_LEN boundary, so the row indices are bounded as well.
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # No score modification is applied in this specialization: scores pass through.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        # Columns past KV_LEN must not contribute: force their scores to -inf.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask expression. With L = in_ptr16[off_z], the result is
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0)
        # where "mod" is the sign-corrected remainder computed below.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        # tmp16..tmp24: n % 2048, shifted by the divisor when the remainder is
        # non-zero and its sign differs from the divisor's.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        # tmp28..tmp35: (n - m) % 2048 with the same sign correction.
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # Apply the mask for a partially masked block.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Probabilities are recomputed in base 2 from the scaled scores and `lse`.
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: the (head-dim, sequence) argument order is reversed because V is transposed.
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification in this specialization: dS passes through.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        # Zero the gradient for columns past KV_LEN.
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # NOTE: scatter_mask is computed but not referenced again in this block.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ: kT is transposed back so the product is dS @ K.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
# Process one (N1 x M1) tile of the dK/dV backward pass: recompute the
# transposed attention probabilities for the tile, then accumulate
# P^T @ dO into `dv` and dS^T @ Q into `dk`.
#
# Inputs used by this block (the leading arg_* / ks* parameters are only
# threaded through and are not read here, apart from `in_ptr16`):
#   dk, dv             running accumulators; updated and returned.
#   qT_ptrs, do_ptrs   pointers to the current transposed-Q and dO tiles.
#   k, v               K / V values already loaded by the caller.
#   DELTA, LSE         base pointers indexed by `offs_m1`.
#   offs_n1, offs_m1   KV-side / Q-side indices of the tile.
#   offs_k, offs_v     head-dimension indices.
#   Q_LEN, KV_LEN      sequence bounds used for masking.
#   in_ptr16           buffer indexed by `off_z`; the loaded value is compared
#                      against row/column indices in the mask below
#                      (presumably a per-batch valid length -- TODO confirm).
#   MATMUL_PRECISION   dtype the probability / dS tiles are cast to before tl.dot.
#   RCP_LN2            factor applied to the scores before exp2.
#   IS_FULL_BLOCKS     when true, the mask evaluation is skipped entirely.
# Returns: the updated (dk, dv).
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time configuration baked into this specialization of the kernel.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = False
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: the (head-dim, sequence) argument order is reversed because Q is transposed.
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf entries with 0.0 (presumably so the later
    # `post_mod_scores - lse` cannot evaluate (-inf) - (-inf) -- TODO confirm).
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # The tile is transposed here: m indexes columns, n indexes rows.
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but because we iterate
    # across the M dim here, n can still read out of bounds for the program
    # IDs spanning the KV_LEN boundary, so the row indices are bounded as well.
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    # NOTE: pre_mod_scores is not referenced again within this block.
    pre_mod_scores = qkT
    # No score modification is applied in this specialization: scores pass through.
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        # Query positions past Q_LEN must not contribute: force their scores to -inf.
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask expression. With L = in_ptr16[off_z], the result is
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0)
        # where "mod" is the sign-corrected remainder computed below.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        # tmp56..tmp64: n % 2048, shifted by the divisor when the remainder is
        # non-zero and its sign differs from the divisor's.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # tmp68..tmp75: (n - m) % 2048 with the same sign correction.
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # Apply the mask for a partially masked block (this branch only runs
        # when the block is not a full block).
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # Transposed probabilities, recomputed in base 2 from the scaled scores and `lse`.
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # No joint modification in this specialization: dS^T passes through.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        # Zero the gradient for query positions past Q_LEN.
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is the constexpr True above, so this branch is never taken in
    # this specialization.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK: qT is transposed back so the product is dS^T @ Q.
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
# Load a 2-D tile, masking whichever axes are not flagged as divisible.
#
#   ptr                      base pointer, or an already-built tile of pointers
#                            when the strides are None.
#   offs_m, offs_n           indices along the first / second tile axis.
#   stride_m, stride_n       element strides; tile addresses are only derived
#                            from the offsets when BOTH are provided.
#   IS_DIVISIBLE_M / _N      constexpr flags; a True flag skips the bounds
#                            mask for that axis.
#   M_LEN, N_LEN             upper bounds the offsets are compared against.
#
# Masked-out elements are filled with 0.0.
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Build the tile addresses only when the caller supplied both strides.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m
        ptr = ptr + offs_n[None, :] * stride_n

    # Dispatch on the two compile-time divisibility flags.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Neither axis needs a bounds check.
            return tl.load(ptr)
        else:
            # Only the second axis can run out of range.
            cols_in_range = offs_n[None, :] < N_LEN
            return tl.load(ptr, mask=cols_in_range, other=0.0)
    else:
        if IS_DIVISIBLE_N:
            # Only the first axis can run out of range.
            rows_in_range = offs_m[:, None] < M_LEN
            return tl.load(ptr, mask=rows_in_range, other=0.0)
        else:
            # Both axes need checking.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
a/SpecForge-ext/cache/compiled_kernels/44/c444sn4254wny52itl5mlassuxuptb3kc3p6r3r3stsx4lyt6t3r.py b/SpecForge-ext/cache/compiled_kernels/44/c444sn4254wny52itl5mlassuxuptb3kc3p6r3r3stsx4lyt6t3r.py new file mode 100644 index 0000000000000000000000000000000000000000..060f967bb0d81c2065f4dcb42defd961219a1cef --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/44/c444sn4254wny52itl5mlassuxuptb3kc3p6r3r3stsx4lyt6t3r.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7m/c7mmadjna7dltm72lxvsoktdadnw2jtxufsj2eoflefh2r5jo4gq.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3g/c3gmll5f74ypurxotx73fmzfaldqb5oaua4nvjazcxzzjafvjoo2.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, 
view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:1"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 
s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 
1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 
128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# 
%convert_element_type : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as 
tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = 
tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, 
tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rf/crf75ojmgx3s35d4vq6bm6ahr7jskra4xhldkhobbo2elpsvqhja.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:1" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 128*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 
1]cuda:1" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: 
torch.int32, device: cuda:1, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1}) +# %where : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, 
%convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): 
[['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + 
roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gn/cgnsrigp6qu2lbqq76g27kshvt2bzkyjnupza5ds7znhjxrnwhif.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[8, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), 16, 1]cuda:1" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:1" = 
PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = 
call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit 
+def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf12 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + 
s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream1) + buf19 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream1) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 128*((127 + s37) // 128) + stream1 = get_raw_stream(1) + 
triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream1) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream1) + del buf1 + del buf4 + buf26 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + 
s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream1) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream1 = get_raw_stream(1) + 
triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream1) + del buf5 + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream1) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py 
b/SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py new file mode 100644 index 0000000000000000000000000000000000000000..4791966e5a5b15be3eb0f2b25bd1ec227aa12f8a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/44/c44kjaqobzzgvmjyd6g2ial2qqjsfjve7v3q6locl7ykhfs2td6p.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/47/c477sca7o5mbhj2pknepbw5b3rzush4uzefidcyxm6ysescvabgf.py b/SpecForge-ext/cache/compiled_kernels/47/c477sca7o5mbhj2pknepbw5b3rzush4uzefidcyxm6ysescvabgf.py new file mode 100644 index 0000000000000000000000000000000000000000..303baf2604202ae138dca0a55c83b17009eaf8dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/47/c477sca7o5mbhj2pknepbw5b3rzush4uzefidcyxm6ysescvabgf.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/47/c47ib4wwcplmaudqv7246conlxcjjylxi5ahlxye6ebv6onrgoxg.py b/SpecForge-ext/cache/compiled_kernels/47/c47ib4wwcplmaudqv7246conlxcjjylxi5ahlxye6ebv6onrgoxg.py new file mode 100644 index 0000000000000000000000000000000000000000..eecba3c01f055aab09d797b5affe3f2c80816913 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/47/c47ib4wwcplmaudqv7246conlxcjjylxi5ahlxye6ebv6onrgoxg.py @@ -0,0 +1,37 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = 
r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py b/SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2c8296af049fa3e4adb170844e9a44f0152608 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/47/c47rre73srjytbq7fn2vqqophv2xicf4cmcdwzenpxfzmxo7jyzi.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/4d/71d6529f8e555bde19385922da8b4def675c0510b5b664b1bf4faebe9f8928eb.best_config b/SpecForge-ext/cache/compiled_kernels/4d/71d6529f8e555bde19385922da8b4def675c0510b5b664b1bf4faebe9f8928eb.best_config new file mode 100644 index 
0000000000000000000000000000000000000000..7921a12b007ca46a00e959ad115401adf0bd4471 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4d/71d6529f8e555bde19385922da8b4def675c0510b5b664b1bf4faebe9f8928eb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "B46RWD5PEMKEQR7EBR6IG3BGTK4P7CWBVNOODNZQX5NAVXXVIH2A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4d/c4dlhtn4sdcx4l7dqawohb4u6hvu3xnzxvwazc5qrp6haqgnm3ev.py b/SpecForge-ext/cache/compiled_kernels/4d/c4dlhtn4sdcx4l7dqawohb4u6hvu3xnzxvwazc5qrp6haqgnm3ev.py new file mode 100644 index 0000000000000000000000000000000000000000..44193b7b0366afab4b01bec085ebb493950144f8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4d/c4dlhtn4sdcx4l7dqawohb4u6hvu3xnzxvwazc5qrp6haqgnm3ev.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py b/SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py new file mode 100644 index 0000000000000000000000000000000000000000..b9bbb584a2610116b0500060b002b0a2e4cc2ca2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 
'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) 
+@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension 
of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py b/SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py new file mode 100644 index 0000000000000000000000000000000000000000..13bd7003aeb441f56a922c22d4749c7184ff64cc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/57/c57svaeo74ael4oxqveudfvhx4xfmu3ikrmvljcnixb4kiqagrzn.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 
32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5b/c5bepmksmwsj5jf67nbhozcficyzxyuodtllwlc5wms64ubwgqh6.py b/SpecForge-ext/cache/compiled_kernels/5b/c5bepmksmwsj5jf67nbhozcficyzxyuodtllwlc5wms64ubwgqh6.py new file mode 100644 index 0000000000000000000000000000000000000000..3a52b7ad1d0e543044fd6ee71cdb25ad1339a1b5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5b/c5bepmksmwsj5jf67nbhozcficyzxyuodtllwlc5wms64ubwgqh6.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, 
regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, 
arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # 
V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5b/c5bjpwebwz42wx5vrxljhccvdqdyttlxd2hpbtzw7ia3oq6c33ne.py b/SpecForge-ext/cache/compiled_kernels/5b/c5bjpwebwz42wx5vrxljhccvdqdyttlxd2hpbtzw7ia3oq6c33ne.py new file mode 100644 index 0000000000000000000000000000000000000000..91b12e5d118f7cd44518a1b2a97eb2d909b68a5d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5b/c5bjpwebwz42wx5vrxljhccvdqdyttlxd2hpbtzw7ia3oq6c33ne.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = 
r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py b/SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2a463eba6f2e5a18a113d693da6c7fca1b7e07 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, 
ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, 
eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = 
tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5g/c5g4egnryommgtc4braxeh3xxhfypsb6ec3v4sx3gxmfoeho5bzl.py b/SpecForge-ext/cache/compiled_kernels/5g/c5g4egnryommgtc4braxeh3xxhfypsb6ec3v4sx3gxmfoeho5bzl.py new file mode 100644 index 0000000000000000000000000000000000000000..c33474d1c37c84c80f3f5fb76fb8f293e9fc58c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5g/c5g4egnryommgtc4braxeh3xxhfypsb6ec3v4sx3gxmfoeho5bzl.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = 
torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/of/cofwz2ulo5xzqhau3cyhif5tweuyn7cqvg27usnkxh25zmnsmxqm.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, 
DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = 
tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2s/c2sberivpmlochymf7ley5gcelz4rvptoxduf2zbu47hbzpiry2g.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_13 : 
Tensor "i32[8, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:0" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[8, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:0" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[8, 1, s100][s100, s100, 1]cuda:0" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[8, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:0" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from 
torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': 
False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q 
= arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. 
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # 
then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 
+ s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (8, ), (1, )) + assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (8, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + ps0 = 32*s37 + buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 256*s37 + stream0 = get_raw_stream(0) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream0) + del getitem + buf3 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream0 = get_raw_stream(0) + 
triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 8, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 2009 + primals_11 = 2009 + primals_15 = 2009 + primals_7 = 16 + primals_8 = 16 + primals_12 = 16 + primals_16 = 16 + primals_18 = 16 + primals_19 = 16 + primals_21 = 16 + primals_24 = 16 + primals_23 = 16 + primals_26 = 16 + primals_29 = 16 + primals_28 = 16 + primals_2 = rand_strided((8, 32, 2009, 128), (8228864, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 8, 2009, 128), (2057216, 257152, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = rand_strided((8, 8, 2009, 128), (2057216, 257152, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_9 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_17 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + 
primals_20 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_22 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_25 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_27 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_30 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((8, 32, 2009, 128), (8228864, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2009), (64288, 2009, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2009, 128), (8228864, 257152, 128, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5g/f7b122e9e44d29c2d695b9633c17bd6eee619900dd2325a29224a34ca8164da3.best_config b/SpecForge-ext/cache/compiled_kernels/5g/f7b122e9e44d29c2d695b9633c17bd6eee619900dd2325a29224a34ca8164da3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..128251849e0d90499e31f76727557122755609e2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5g/f7b122e9e44d29c2d695b9633c17bd6eee619900dd2325a29224a34ca8164da3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": 
"3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5j/c5jkwkcyjhbh7wkjfrwxf42i764iqjvwig2sk7me5dmfpiuqgldn.py b/SpecForge-ext/cache/compiled_kernels/5j/c5jkwkcyjhbh7wkjfrwxf42i764iqjvwig2sk7me5dmfpiuqgldn.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ffe73ee44c18cf09d1844a5ae59838d9ef8a9c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5j/c5jkwkcyjhbh7wkjfrwxf42i764iqjvwig2sk7me5dmfpiuqgldn.py @@ -0,0 +1,410 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = 
torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xg/cxgas2jt2e7frwpm3nd5h7wlzdo2fb2yonkvkfoarpyafiwosg23.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:3" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 
1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fv/cfvm655j5cm4524gmdyhr7yli6dffpakysuycuozkqmyuaonkwbg.py +# Topologically Sorted Source Nodes: [argmax_1], 
Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': 
False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/tz/ctz36xgzfd5jcgzrek7dztousbfkxkdiqxeixt3t36guolxobku7.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:3" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:3" = 
PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg2_1] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': 
True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/eu/ceuhopmcdleig6m43h7kk4fhghkl5w2umfjuyngxydc4pr3zpumg.py +# Topologically Sorted Source Nodes: [sum_2], 
Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': 
{'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mw/cmw2hjnuubs2eh7cuc54pem6cjhaz4jgplmqlhrsxfkzljxf7ndg.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:3" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args 
= (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 
'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, 
arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (8, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream3 = get_raw_stream(3) + triton_red_fused_argmax_0.run(arg0_1, buf0, 16384, 32000, stream=stream3) + del arg0_1 + buf1 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream3 = get_raw_stream(3) + triton_red_fused_argmax_1.run(arg1_1, buf1, 16384, 32000, stream=stream3) + del arg1_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + stream3 = get_raw_stream(3) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, buf3, 2, 8192, stream=stream3) + del arg2_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + stream3 = get_raw_stream(3) + triton_red_fused_sum_3.run(arg3_1, buf5, 2, 8192, stream=stream3) + del arg3_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream3 = get_raw_stream(3) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream3) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, 
repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + arg1_1 = rand_strided((8, 2048, 32000), (65760000, 32000, 1), device='cuda:3', dtype=torch.float32) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:3', dtype=torch.int64) + arg3_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:3', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5j/c5jvrblz5ym34kn4ssfnzfxabvx53ffrcnqlmwnjv3gmkqfkgo4v.py b/SpecForge-ext/cache/compiled_kernels/5j/c5jvrblz5ym34kn4ssfnzfxabvx53ffrcnqlmwnjv3gmkqfkgo4v.py new file mode 100644 index 0000000000000000000000000000000000000000..957e19a7d949031001fbdccc4bd3272419c65b64 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5j/c5jvrblz5ym34kn4ssfnzfxabvx53ffrcnqlmwnjv3gmkqfkgo4v.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor 
+_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/s4/cs4afdu7ezaeekoshfdryoga6jabuq2nrx5xdkgxrehrkuvy5jri.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, 
{PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 
'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3n/c3nhszg76l7meq3aapcfdfxknr3a44zlaammv6ewmbubopxjxjqh.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=primals_6] +# 
%getitem_1 : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[2, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:3" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[2, 1, s100][s100, s100, 1]cuda:3" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[2, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:3" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, 
%primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): 
[['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 
+ s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (2, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + ps0 = 32*s37 + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 64*s37 + stream3 = get_raw_stream(3) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream3) + del getitem + buf3 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + 
triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 2, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 2014 + primals_11 = 2014 + primals_15 = 2014 + primals_7 = 16 + primals_8 = 16 + primals_12 = 16 + primals_16 = 16 + primals_18 = 16 + primals_19 = 16 + primals_21 = 16 + primals_24 = 16 + primals_23 = 16 + primals_26 = 16 + primals_29 = 16 + primals_28 = 16 + primals_2 = rand_strided((2, 32, 2014, 128), (8249344, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 8, 2014, 128), (2062336, 257792, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_6 = rand_strided((2, 8, 2014, 128), (2062336, 257792, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_9 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_17 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + 
primals_20 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_22 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_25 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_27 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_30 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((2, 32, 2014, 128), (8249344, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2014), (64448, 2014, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2014, 128), (8249344, 257792, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py b/SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py new file mode 100644 index 0000000000000000000000000000000000000000..e0cee102b98273d187c0268d422802b1a24e8bc2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers 
import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 
+ (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5o/fd9062ce8e19c42a2ac9826803e021d98494b253c62ff7bfe753f34e0c863929.best_config b/SpecForge-ext/cache/compiled_kernels/5o/fd9062ce8e19c42a2ac9826803e021d98494b253c62ff7bfe753f34e0c863929.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3fc56f57c375dacceeb71ea2e7a129667d8c493f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5o/fd9062ce8e19c42a2ac9826803e021d98494b253c62ff7bfe753f34e0c863929.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": 
"3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 50, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5r/67cf09099929c6923eb0884c24406ef33ddeccd796aae57f0484cb9e81164741.best_config b/SpecForge-ext/cache/compiled_kernels/5r/67cf09099929c6923eb0884c24406ef33ddeccd796aae57f0484cb9e81164741.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a570e8d663ff6e600f50df05a811c859065ec3c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5r/67cf09099929c6923eb0884c24406ef33ddeccd796aae57f0484cb9e81164741.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5r/a92927e5b439ed1b110bb08d838b1d456f558deda739d61f97499093c88a877a.best_config b/SpecForge-ext/cache/compiled_kernels/5r/a92927e5b439ed1b110bb08d838b1d456f558deda739d61f97499093c88a877a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..96dd92ec5b0239e781a57a11e2928c5c0f286636 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5r/a92927e5b439ed1b110bb08d838b1d456f558deda739d61f97499093c88a877a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/5r/c5rruemruzlohybkl4bagtqtk5athtuxf3eoj37rjptdltrrri3r.py 
b/SpecForge-ext/cache/compiled_kernels/5r/c5rruemruzlohybkl4bagtqtk5athtuxf3eoj37rjptdltrrri3r.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9d49965895d9f6985a8f0ae22a40e1f65e9b18 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5r/c5rruemruzlohybkl4bagtqtk5athtuxf3eoj37rjptdltrrri3r.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/5r/c5rs7ak7dbb5csdzszewryjtnxnhv7xpwdjvgqrsp5lfmmej4poi.py b/SpecForge-ext/cache/compiled_kernels/5r/c5rs7ak7dbb5csdzszewryjtnxnhv7xpwdjvgqrsp5lfmmej4poi.py new file mode 100644 index 0000000000000000000000000000000000000000..c5158fee48d05bd1c30f21f70141771318cbd3e0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5r/c5rs7ak7dbb5csdzszewryjtnxnhv7xpwdjvgqrsp5lfmmej4poi.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + 
xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py b/SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py new file mode 100644 index 0000000000000000000000000000000000000000..d248b3fd0f5f6c888333b3e7c71eaf38bba11907 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5w/c5w735qbviioww7vfjj36tk57xo254oei3wqkunaiekkjd5pfcph.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 
'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/5z/c5z5oj5ee2bvvg2pkzwf6smszdy73565nillm7gopvokmvrvu2dp.py b/SpecForge-ext/cache/compiled_kernels/5z/c5z5oj5ee2bvvg2pkzwf6smszdy73565nillm7gopvokmvrvu2dp.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1fd4de25a335da1a22cf07b87f1e9f2fb7a053 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/5z/c5z5oj5ee2bvvg2pkzwf6smszdy73565nillm7gopvokmvrvu2dp.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import 
maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:4" = 
PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, 
major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : 
tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, 
primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, 
primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream4) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1543 + primals_2 = rand_strided((2, 32, 1543, 128), (6320128, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = 1543 + primals_4 = rand_strided((2, 8, 1543, 128), (1580032, 197504, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_5 = 1543 + primals_6 = rand_strided((2, 8, 1543, 128), (1580032, 197504, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = 13 + primals_8 = 13 + primals_9 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:4', dtype=torch.int32) + primals_10 = 1543 + primals_11 = 1543 + primals_12 = 13 + primals_13 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64) + primals_15 = 1543 + primals_16 = 13 + primals_17 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:4', dtype=torch.int32) + primals_18 = 13 + primals_19 = 13 + primals_20 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:4', dtype=torch.int32) + primals_21 = 13 + primals_22 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:4', dtype=torch.int32) + primals_23 = 13 + primals_24 = 13 + primals_25 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:4', dtype=torch.int32) + primals_26 = 13 + primals_27 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:4', dtype=torch.int32) + 
primals_28 = 13 + primals_29 = 13 + primals_30 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py b/SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py new file mode 100644 index 0000000000000000000000000000000000000000..1438ea9eacac1b8aedacd52a6ff092cd84811474 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6j/c6jzxztdxbjv5b23nfmgzgtizqp77h7aeak5j2jukmz3roqeiw3k.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/6j/e9590d30530b6f20cd8332cd18dfb56bc33c5ce0f73ebafd83fbd8da1a7ab8fe.best_config b/SpecForge-ext/cache/compiled_kernels/6j/e9590d30530b6f20cd8332cd18dfb56bc33c5ce0f73ebafd83fbd8da1a7ab8fe.best_config new file mode 100644 index 0000000000000000000000000000000000000000..96dd92ec5b0239e781a57a11e2928c5c0f286636 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6j/e9590d30530b6f20cd8332cd18dfb56bc33c5ce0f73ebafd83fbd8da1a7ab8fe.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/6m/c6mwcfy2ykv3p5alrzh4sx4ajhl5davetqobw2pytyc2kalbo2wk.py b/SpecForge-ext/cache/compiled_kernels/6m/c6mwcfy2ykv3p5alrzh4sx4ajhl5davetqobw2pytyc2kalbo2wk.py new file mode 100644 index 
0000000000000000000000000000000000000000..ebeb6e2d229ad0d6422e9385c7aa44f962abd378 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6m/c6mwcfy2ykv3p5alrzh4sx4ajhl5davetqobw2pytyc2kalbo2wk.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mb/cmboruk2gyuhq43degftaqzb2abxergkmetbzmcgprn7eynqywpe.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: 
[aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_6 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s14, 151936][151936*s14, 151936, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s14][s14, 1]cuda:7" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:7" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[8, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[8, s14][s14, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_6 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# return %argmax,%mul_6 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 
262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, 
R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (8, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (8, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, s14, 1), 
(s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 8*s14 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream7) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2025 + arg1_1 = rand_strided((8, 2025, 151936), (307670400, 151936, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:7', dtype=torch.bool) + arg3_1 = rand_strided((8, 2025, 1), (2025, 1, 1), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/6o/c6o7jlqhfbi4ry3uni47hefilsmfptqopfdxwc3plgg65s2mqzse.py b/SpecForge-ext/cache/compiled_kernels/6o/c6o7jlqhfbi4ry3uni47hefilsmfptqopfdxwc3plgg65s2mqzse.py new file mode 100644 index 0000000000000000000000000000000000000000..293edc5bd58aae994976a0dc62d6803bd5843d9d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6o/c6o7jlqhfbi4ry3uni47hefilsmfptqopfdxwc3plgg65s2mqzse.py @@ -0,0 +1,322 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math 
import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gy/cgypquf4bysldt6yik5b24uoywlbfrbaqlpvmqscjgvur4u7ckpi.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s25, s9, 
s24][s24*s25*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = 
call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 
+triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) 
+ ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bp/cbpuekklhjdszdnpjmnzg77zhi5rum3iueweicitfxwda6abrl2a.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# 
%index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = 
call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 
'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, 
eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del 
async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream5) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, 
aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream5) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 2 + primals_11 = 32 + primals_13 = 8 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/6r/734ee9f72fcbbc036c304bd9fc428175dc6febf6da61f182679d20ad4d8b7f41.best_config b/SpecForge-ext/cache/compiled_kernels/6r/734ee9f72fcbbc036c304bd9fc428175dc6febf6da61f182679d20ad4d8b7f41.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b9c83cd70cc4f7d46eca037549afe001d843ad6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6r/734ee9f72fcbbc036c304bd9fc428175dc6febf6da61f182679d20ad4d8b7f41.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py b/SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py new file mode 100644 index 0000000000000000000000000000000000000000..d8edf768799bc1552b63e856dc666fa48a5a3090 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + 
tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/6r/c6rbvgm53jr3nux66durqhisanccgaebzxcdjdhdrphqjpyu2t5r.py b/SpecForge-ext/cache/compiled_kernels/6r/c6rbvgm53jr3nux66durqhisanccgaebzxcdjdhdrphqjpyu2t5r.py new file mode 100644 index 0000000000000000000000000000000000000000..dff889f2e461a03841d2d244b1b09652c4b67470 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6r/c6rbvgm53jr3nux66durqhisanccgaebzxcdjdhdrphqjpyu2t5r.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, 
False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/6w/c6wvripzhtxw7rn3rkxdy7t5nn75rtccbq7dutodyr7a2qtepl2o.py b/SpecForge-ext/cache/compiled_kernels/6w/c6wvripzhtxw7rn3rkxdy7t5nn75rtccbq7dutodyr7a2qtepl2o.py new file mode 100644 index 0000000000000000000000000000000000000000..5f93d8144cdbdd9e3c35d43b18ff07ddece0b5c8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6w/c6wvripzhtxw7rn3rkxdy7t5nn75rtccbq7dutodyr7a2qtepl2o.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = 
tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/6z/c6zba2r22yyctp3hlaofoasgbhbtwqg7txp443m7rffkvpbhn34q.py b/SpecForge-ext/cache/compiled_kernels/6z/c6zba2r22yyctp3hlaofoasgbhbtwqg7txp443m7rffkvpbhn34q.py new file mode 100644 index 0000000000000000000000000000000000000000..14913c870c4e282c8c403596c9504e38a2b7048f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/6z/c6zba2r22yyctp3hlaofoasgbhbtwqg7txp443m7rffkvpbhn34q.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = 
triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/7a/c7avdnhdkg25qkzpvb4jgb3wfrta3u7po7rrnynujrskgetlvslk.py b/SpecForge-ext/cache/compiled_kernels/7a/c7avdnhdkg25qkzpvb4jgb3wfrta3u7po7rrnynujrskgetlvslk.py new file mode 100644 index 0000000000000000000000000000000000000000..dc448a9819e802853b414b734f5198b8c8ae4e7d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7a/c7avdnhdkg25qkzpvb4jgb3wfrta3u7po7rrnynujrskgetlvslk.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, 
+inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + 
BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? 
If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/7d/c7dz7a56ncjefkfohcx6ikzhlzkt7ndm7e4lx7jklfxmuwr6ixv4.py b/SpecForge-ext/cache/compiled_kernels/7d/c7dz7a56ncjefkfohcx6ikzhlzkt7ndm7e4lx7jklfxmuwr6ixv4.py new file mode 100644 index 0000000000000000000000000000000000000000..a744745ab7ce71fca3ccb46c19a0d24397f9e0ec --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7d/c7dz7a56ncjefkfohcx6ikzhlzkt7ndm7e4lx7jklfxmuwr6ixv4.py @@ -0,0 
+1,161 @@ +# AOT ID: ['11_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nv/cnvg2yg4bk5hvx3icicawj3fv22kweu7g2qr57nzvvakuyk5ie46.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:0" = 
PlaceHolder[target=arg1_1] +# %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:0" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:0" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, 
False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s67 = arg0_1 + assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67 + stream0 = get_raw_stream(0) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream0) + del arg1_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def 
benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1130 + arg1_1 = rand_strided((2, 1130, 32000), (36160000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/7n/c7n5dmnt7xgsigufe4yi33j6jfpm6swr4c6a4ixhpeelccbiypdq.py b/SpecForge-ext/cache/compiled_kernels/7n/c7n5dmnt7xgsigufe4yi33j6jfpm6swr4c6a4ixhpeelccbiypdq.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9409cf589a2b9a65004f54cd08f0051eaa3c82 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7n/c7n5dmnt7xgsigufe4yi33j6jfpm6swr4c6a4ixhpeelccbiypdq.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = 
torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/s3/cs3yasiv6lrv3gv7zuav4rb3xghwnuf7k2xh3nhv5aukawhau2k2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:0" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, 
%primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 
'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # 
Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream0) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/a3/ca3dtjga4dneofjf6uchajqfop34u2rd7jlztuwf6iiqiojwc5qt.py b/SpecForge-ext/cache/compiled_kernels/a3/ca3dtjga4dneofjf6uchajqfop34u2rd7jlztuwf6iiqiojwc5qt.py new file mode 100644 index 0000000000000000000000000000000000000000..a5755739f3df75b6006cae316e92423febb100dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/a3/ca3dtjga4dneofjf6uchajqfop34u2rd7jlztuwf6iiqiojwc5qt.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, 
math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + 
x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/a5/ca5ifqrw6w6uu2recydoe7f7w772u2wc3ax4txzstyyzzr5pvmfa.py b/SpecForge-ext/cache/compiled_kernels/a5/ca5ifqrw6w6uu2recydoe7f7w772u2wc3ax4txzstyyzzr5pvmfa.py new file mode 100644 index 0000000000000000000000000000000000000000..4df67f42cef736a4c8b48c9aa753d65f1d67d68a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/a5/ca5ifqrw6w6uu2recydoe7f7w772u2wc3ax4txzstyyzzr5pvmfa.py @@ -0,0 +1,303 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized 
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/i7/ci7k4nsws2k2ul2efnmk5kpdyf23awz656g7ls3el2bksnqpwzrz.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:7" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', 
index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 262144000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, 
_tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/w3/cw3qkrr2ihmnhdpotixr5nwhtnm2i42mu7tpjtb5rih5k77viw3s.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[2, 2048, 32000][65760000, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5d/c5dxklofifxzswxjhdjvko4ncyrk6vkfrbohhy3eg5kffm63zqjg.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:7" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[2, 2048][2048, 1]cuda:7" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:7" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %sum_1 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div 
+triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def 
triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = 
partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (2, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream7 = get_raw_stream(7) + triton_red_fused_argmax_0.run(arg0_1, buf0, 4096, 32000, stream=stream7) + del arg0_1 + buf1 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream7 = get_raw_stream(7) + triton_red_fused_argmax_1.run(arg1_1, buf1, 4096, 32000, stream=stream7) + del arg1_1 + buf4 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream7 = get_raw_stream(7) + triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, arg3_1, buf4, 1, 4096, stream=stream7) + del arg2_1 + del arg3_1 + del buf0 + del buf1 + return (buf4, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:7', dtype=torch.bfloat16) + arg1_1 = rand_strided((2, 2048, 32000), (65760000, 32000, 1), device='cuda:7', 
dtype=torch.float32) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.int64) + arg3_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/a2l2i4as3fxihygn6tifr2c7sv7z7uhhd3afgps3no524n4ulrlz/xc4sidtdcxu2oe6234y3vnte4napepvz5b3k5dai7p7g3ozv6ot b/SpecForge-ext/cache/compiled_kernels/aotautograd/a2l2i4as3fxihygn6tifr2c7sv7z7uhhd3afgps3no524n4ulrlz/xc4sidtdcxu2oe6234y3vnte4napepvz5b3k5dai7p7g3ozv6ot new file mode 100644 index 0000000000000000000000000000000000000000..1fc033ab08662bb56b03546c905cc5caa2d3815c Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/a2l2i4as3fxihygn6tifr2c7sv7z7uhhd3afgps3no524n4ulrlz/xc4sidtdcxu2oe6234y3vnte4napepvz5b3k5dai7p7g3ozv6ot differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/a456ohte55cfk34vt5ictu52mkuerc2s3xmdc7ukdfial345mmzj/bv4svuirv4gwtxrgkp3cby6wdbm4nabyzxn7bp6vk5l4sqlsol6 b/SpecForge-ext/cache/compiled_kernels/aotautograd/a456ohte55cfk34vt5ictu52mkuerc2s3xmdc7ukdfial345mmzj/bv4svuirv4gwtxrgkp3cby6wdbm4nabyzxn7bp6vk5l4sqlsol6 new file mode 100644 index 0000000000000000000000000000000000000000..0e7b7c9ae02f02447077283ce8334c779a533c84 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/a456ohte55cfk34vt5ictu52mkuerc2s3xmdc7ukdfial345mmzj/bv4svuirv4gwtxrgkp3cby6wdbm4nabyzxn7bp6vk5l4sqlsol6 differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/a62uezssexfkttaiqhlsqrga62w4o3aigkc5fgz2cf2jf2fjnknu/ih4smsfn6wh2ck5gd73imsfzb7p5cehaxijxv5rbnqnkajttwfi 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/a62uezssexfkttaiqhlsqrga62w4o3aigkc5fgz2cf2jf2fjnknu/ih4smsfn6wh2ck5gd73imsfzb7p5cehaxijxv5rbnqnkajttwfi new file mode 100644 index 0000000000000000000000000000000000000000..ea44bf94229efa08d66f2867256671bdbcdbe990 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/a62uezssexfkttaiqhlsqrga62w4o3aigkc5fgz2cf2jf2fjnknu/ih4smsfn6wh2ck5gd73imsfzb7p5cehaxijxv5rbnqnkajttwfi differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/aaopjbc2qt6rz6f2ee3k6fpwdg6tvr35uxq5xhaovifrjcqlmbm5/jmy44dduaxjnjhua3sixyyroj5oo3uh2kzp573iou5xv76pofmt b/SpecForge-ext/cache/compiled_kernels/aotautograd/aaopjbc2qt6rz6f2ee3k6fpwdg6tvr35uxq5xhaovifrjcqlmbm5/jmy44dduaxjnjhua3sixyyroj5oo3uh2kzp573iou5xv76pofmt new file mode 100644 index 0000000000000000000000000000000000000000..edb1119efa322a9b32e9af1cf9421b6fd12468bc Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/aaopjbc2qt6rz6f2ee3k6fpwdg6tvr35uxq5xhaovifrjcqlmbm5/jmy44dduaxjnjhua3sixyyroj5oo3uh2kzp573iou5xv76pofmt differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/ab4r2mf76rggr73biag5n7wsbamntsirluyhr7bxmkcph3jgul3o/fb3tea2fqx7epnylieagusmrpykxlrddub5ewxl26pth7pivb25 b/SpecForge-ext/cache/compiled_kernels/aotautograd/ab4r2mf76rggr73biag5n7wsbamntsirluyhr7bxmkcph3jgul3o/fb3tea2fqx7epnylieagusmrpykxlrddub5ewxl26pth7pivb25 new file mode 100644 index 0000000000000000000000000000000000000000..d1e8bbaa34a1f26cfebecc6a95b8216b77bb31b9 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/ab4r2mf76rggr73biag5n7wsbamntsirluyhr7bxmkcph3jgul3o/fb3tea2fqx7epnylieagusmrpykxlrddub5ewxl26pth7pivb25 differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/abv7t3czhy3uk7fr3mpib27u2hw6nynh4z3nquhfh6ztudbjhjey/a2i6qau2htkihiuqomjqlnbtg66jrzqunom2twsu2u3qt3bkj2i 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/abv7t3czhy3uk7fr3mpib27u2hw6nynh4z3nquhfh6ztudbjhjey/a2i6qau2htkihiuqomjqlnbtg66jrzqunom2twsu2u3qt3bkj2i new file mode 100644 index 0000000000000000000000000000000000000000..af62d0dc822442b71c0f9041c5bb431576357ad1 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/abv7t3czhy3uk7fr3mpib27u2hw6nynh4z3nquhfh6ztudbjhjey/a2i6qau2htkihiuqomjqlnbtg66jrzqunom2twsu2u3qt3bkj2i differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/acamb65tznnqf6srygto3fra5dfda2jmjp6s3mk4mjukrbptfnyj/7ekph3cbu6xwjzdk7wv4vmef5fne7ba6dmse42t6x7skuvbcl43 b/SpecForge-ext/cache/compiled_kernels/aotautograd/acamb65tznnqf6srygto3fra5dfda2jmjp6s3mk4mjukrbptfnyj/7ekph3cbu6xwjzdk7wv4vmef5fne7ba6dmse42t6x7skuvbcl43 new file mode 100644 index 0000000000000000000000000000000000000000..72e219b3c40a84118a6529f98b7f774c41ba1998 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/acamb65tznnqf6srygto3fra5dfda2jmjp6s3mk4mjukrbptfnyj/7ekph3cbu6xwjzdk7wv4vmef5fne7ba6dmse42t6x7skuvbcl43 differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/adrt7qyfbnhn7q56izbq3rg2i6zjjibvj5jv4inw6srybb5buu6d/7e7wpzxk5zupxxqw2phknxiw5mywegwvqdhqbbv6h7axa5yp3qz b/SpecForge-ext/cache/compiled_kernels/aotautograd/adrt7qyfbnhn7q56izbq3rg2i6zjjibvj5jv4inw6srybb5buu6d/7e7wpzxk5zupxxqw2phknxiw5mywegwvqdhqbbv6h7axa5yp3qz new file mode 100644 index 0000000000000000000000000000000000000000..5699ac7f467112711092a1e6c1b399ead3717f51 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/adrt7qyfbnhn7q56izbq3rg2i6zjjibvj5jv4inw6srybb5buu6d/7e7wpzxk5zupxxqw2phknxiw5mywegwvqdhqbbv6h7axa5yp3qz differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/aeexk6dxkvicenso7gzqonf2ipyz6hiqgyui23b2nf3biw2qek7s/i4h5mkorflin6q65j4k5az6u6ctguosze67p6nmv6c5zpgo3gkj 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/aeexk6dxkvicenso7gzqonf2ipyz6hiqgyui23b2nf3biw2qek7s/i4h5mkorflin6q65j4k5az6u6ctguosze67p6nmv6c5zpgo3gkj new file mode 100644 index 0000000000000000000000000000000000000000..1811a2ae43b85c5fcb1859a32533c2000b6d6a38 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/aeexk6dxkvicenso7gzqonf2ipyz6hiqgyui23b2nf3biw2qek7s/i4h5mkorflin6q65j4k5az6u6ctguosze67p6nmv6c5zpgo3gkj differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/agrfzybyxd2c3zlgwlkjr62faa5uqzeuqzbvcwi4ppfq3pleiojs/pld5hmmxrugzjoi5pzb6rq4j3fyudglikcelrjzs6k64jogakro b/SpecForge-ext/cache/compiled_kernels/aotautograd/agrfzybyxd2c3zlgwlkjr62faa5uqzeuqzbvcwi4ppfq3pleiojs/pld5hmmxrugzjoi5pzb6rq4j3fyudglikcelrjzs6k64jogakro new file mode 100644 index 0000000000000000000000000000000000000000..6bca4cae2d0a14e374cdd16d0a81cdeffed701ef Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/agrfzybyxd2c3zlgwlkjr62faa5uqzeuqzbvcwi4ppfq3pleiojs/pld5hmmxrugzjoi5pzb6rq4j3fyudglikcelrjzs6k64jogakro differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/ahmatoegix2xeelesidck6impbl4yjy7noyhyofplibqlcpx57eh/efzmvw5knzdb5uafbtkprzcxre3lv5sik3yl65rbkimlfsl36pw b/SpecForge-ext/cache/compiled_kernels/aotautograd/ahmatoegix2xeelesidck6impbl4yjy7noyhyofplibqlcpx57eh/efzmvw5knzdb5uafbtkprzcxre3lv5sik3yl65rbkimlfsl36pw new file mode 100644 index 0000000000000000000000000000000000000000..6847e02e7a115b2a1a46accfa732aeebce03633e Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/ahmatoegix2xeelesidck6impbl4yjy7noyhyofplibqlcpx57eh/efzmvw5knzdb5uafbtkprzcxre3lv5sik3yl65rbkimlfsl36pw differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/ai5ywn6yjrrzfd6rp66hmsut54wgpruttarbq4rpsjyqad2bjykv/mmzmobb5pyx3ydvmalhr5y3t5tyyghacy2cijvon6bfoxxgwbvr 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/ai5ywn6yjrrzfd6rp66hmsut54wgpruttarbq4rpsjyqad2bjykv/mmzmobb5pyx3ydvmalhr5y3t5tyyghacy2cijvon6bfoxxgwbvr new file mode 100644 index 0000000000000000000000000000000000000000..60a5c50d2d3fba67961b3fad753556a811a51399 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/ai5ywn6yjrrzfd6rp66hmsut54wgpruttarbq4rpsjyqad2bjykv/mmzmobb5pyx3ydvmalhr5y3t5tyyghacy2cijvon6bfoxxgwbvr differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/akapppapcdrtzyfomxocaqrnzeesbl5zcx6nhepzu25rie5erdvn/g23m2x4ozmnrf5by44bymx2jmpqa5win6iklerr277dsuv3bkek b/SpecForge-ext/cache/compiled_kernels/aotautograd/akapppapcdrtzyfomxocaqrnzeesbl5zcx6nhepzu25rie5erdvn/g23m2x4ozmnrf5by44bymx2jmpqa5win6iklerr277dsuv3bkek new file mode 100644 index 0000000000000000000000000000000000000000..cfac86540b30ebf77c16490150f88a90a00c5db6 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/akapppapcdrtzyfomxocaqrnzeesbl5zcx6nhepzu25rie5erdvn/g23m2x4ozmnrf5by44bymx2jmpqa5win6iklerr277dsuv3bkek differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/alhkixfhakvf3xwt6phsatdzszmf6kk77cvpugesqhhr43kvjt4v/yrtg3tyqjkcosj233eslccjhh5bbjqfiiwxmwsmt6sh76o75vwa b/SpecForge-ext/cache/compiled_kernels/aotautograd/alhkixfhakvf3xwt6phsatdzszmf6kk77cvpugesqhhr43kvjt4v/yrtg3tyqjkcosj233eslccjhh5bbjqfiiwxmwsmt6sh76o75vwa new file mode 100644 index 0000000000000000000000000000000000000000..5a3266f6d699147b35aa28de8718a367d8374304 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/alhkixfhakvf3xwt6phsatdzszmf6kk77cvpugesqhhr43kvjt4v/yrtg3tyqjkcosj233eslccjhh5bbjqfiiwxmwsmt6sh76o75vwa differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/amlmr7hb4tk2s75yvi66waicyne2pnec2jf3r63sd3y3exeyq2ts/pnih2yndjh5wvo5cfzsb352zycyjqqopy55v6wvthzurf5abkwz 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/amlmr7hb4tk2s75yvi66waicyne2pnec2jf3r63sd3y3exeyq2ts/pnih2yndjh5wvo5cfzsb352zycyjqqopy55v6wvthzurf5abkwz new file mode 100644 index 0000000000000000000000000000000000000000..8e7b306cec04a555d02f919aff4974426fcf4c22 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/amlmr7hb4tk2s75yvi66waicyne2pnec2jf3r63sd3y3exeyq2ts/pnih2yndjh5wvo5cfzsb352zycyjqqopy55v6wvthzurf5abkwz differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/amqbjvk3q3kcrqolinap76mo5c3zsgfxdqhvc2dsgiym35mkodpk/cnew6xtghruee6vtw7ywwk66usud2hxtzsygxiusntoyeal44q4 b/SpecForge-ext/cache/compiled_kernels/aotautograd/amqbjvk3q3kcrqolinap76mo5c3zsgfxdqhvc2dsgiym35mkodpk/cnew6xtghruee6vtw7ywwk66usud2hxtzsygxiusntoyeal44q4 new file mode 100644 index 0000000000000000000000000000000000000000..6342fc93eb597fec1fa4a63cc8aacf0d0d6f852c Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/amqbjvk3q3kcrqolinap76mo5c3zsgfxdqhvc2dsgiym35mkodpk/cnew6xtghruee6vtw7ywwk66usud2hxtzsygxiusntoyeal44q4 differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/aqb6vpjszd3wenobtvz4hdzlvqxreudnr76pyffaevz3h5dbcoh2/ujvb5itj5kjcb2aqiv3xg26edgvkdoahpq5ea5gy3s5bvmb2jlz b/SpecForge-ext/cache/compiled_kernels/aotautograd/aqb6vpjszd3wenobtvz4hdzlvqxreudnr76pyffaevz3h5dbcoh2/ujvb5itj5kjcb2aqiv3xg26edgvkdoahpq5ea5gy3s5bvmb2jlz new file mode 100644 index 0000000000000000000000000000000000000000..77c3b839108f6ab0bbc860f5da6e4648107af265 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/aqb6vpjszd3wenobtvz4hdzlvqxreudnr76pyffaevz3h5dbcoh2/ujvb5itj5kjcb2aqiv3xg26edgvkdoahpq5ea5gy3s5bvmb2jlz differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/avlngrmukln62qt6ggizm4k7rqyo3njz4zexcqhguovncamftnmx/r4vglcjqpj6y77yy4mzkovec2m3vi3t3eabmavk66zevwc5tlic 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/avlngrmukln62qt6ggizm4k7rqyo3njz4zexcqhguovncamftnmx/r4vglcjqpj6y77yy4mzkovec2m3vi3t3eabmavk66zevwc5tlic new file mode 100644 index 0000000000000000000000000000000000000000..906d4961aedbb3f0f02aa98cbd552170a994620e Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/avlngrmukln62qt6ggizm4k7rqyo3njz4zexcqhguovncamftnmx/r4vglcjqpj6y77yy4mzkovec2m3vi3t3eabmavk66zevwc5tlic differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/awvjrhp5yfyuz2dlq3kljdd5ypvnmueb3u2kq4dpp324jwckpvgy/v3brvplf357m6lgatrxxeow6kj7yeebaq6jljnznial57g7hebo b/SpecForge-ext/cache/compiled_kernels/aotautograd/awvjrhp5yfyuz2dlq3kljdd5ypvnmueb3u2kq4dpp324jwckpvgy/v3brvplf357m6lgatrxxeow6kj7yeebaq6jljnznial57g7hebo new file mode 100644 index 0000000000000000000000000000000000000000..dc1d306247f810909711b1916cda38ca53459a1b Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/awvjrhp5yfyuz2dlq3kljdd5ypvnmueb3u2kq4dpp324jwckpvgy/v3brvplf357m6lgatrxxeow6kj7yeebaq6jljnznial57g7hebo differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/ax7a436d5flseatop4dhlhzeadeiux4zjowi4mcflqag67auomif/prqyyeqifq2cbrmdfgea6f7ief6nhttxcm6hestaz27cmgdgbgp b/SpecForge-ext/cache/compiled_kernels/aotautograd/ax7a436d5flseatop4dhlhzeadeiux4zjowi4mcflqag67auomif/prqyyeqifq2cbrmdfgea6f7ief6nhttxcm6hestaz27cmgdgbgp new file mode 100644 index 0000000000000000000000000000000000000000..f22d7be91c9d232a900b770c08835922795fdc29 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/ax7a436d5flseatop4dhlhzeadeiux4zjowi4mcflqag67auomif/prqyyeqifq2cbrmdfgea6f7ief6nhttxcm6hestaz27cmgdgbgp differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/axfbbjuwexflmpuzzww77cz3eolof5sheez5j65o66zcf7fe5vsb/zr7pifqnhve4duoopntuz3bw64caqyqvugyfyfnoxonrzu3elzw 
b/SpecForge-ext/cache/compiled_kernels/aotautograd/axfbbjuwexflmpuzzww77cz3eolof5sheez5j65o66zcf7fe5vsb/zr7pifqnhve4duoopntuz3bw64caqyqvugyfyfnoxonrzu3elzw new file mode 100644 index 0000000000000000000000000000000000000000..5d0e83d4e79216632e769296bf359217515f91ba Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/axfbbjuwexflmpuzzww77cz3eolof5sheez5j65o66zcf7fe5vsb/zr7pifqnhve4duoopntuz3bw64caqyqvugyfyfnoxonrzu3elzw differ diff --git a/SpecForge-ext/cache/compiled_kernels/aotautograd/axrbyqm2q7yyjvihxrmlltrk3g6bm4gonysyorg7wcuv7uqe7vjb/u6pbkc7bgpnhglcb6tvtybdlowkajitlu5ev2vd53h5l4a7iim5 b/SpecForge-ext/cache/compiled_kernels/aotautograd/axrbyqm2q7yyjvihxrmlltrk3g6bm4gonysyorg7wcuv7uqe7vjb/u6pbkc7bgpnhglcb6tvtybdlowkajitlu5ev2vd53h5l4a7iim5 new file mode 100644 index 0000000000000000000000000000000000000000..34e2b19f544a288c821caa872e9da5c4ed73e2fb Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/aotautograd/axrbyqm2q7yyjvihxrmlltrk3g6bm4gonysyorg7wcuv7uqe7vjb/u6pbkc7bgpnhglcb6tvtybdlowkajitlu5ev2vd53h5l4a7iim5 differ diff --git a/SpecForge-ext/cache/compiled_kernels/b3/ab1ea38acb4bf1df3980644f9601c8115dc462d59ebac29bb9dcc4573e4c9191.best_config b/SpecForge-ext/cache/compiled_kernels/b3/ab1ea38acb4bf1df3980644f9601c8115dc462d59ebac29bb9dcc4573e4c9191.best_config new file mode 100644 index 0000000000000000000000000000000000000000..252797a257cbca60dcb912b3691ab023979ae155 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b3/ab1ea38acb4bf1df3980644f9601c8115dc462d59ebac29bb9dcc4573e4c9191.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 48, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/b3/cb3wpx6rjkaergqliask6nznyzrunubqguqa3smp5sn7b4uz6idx.py b/SpecForge-ext/cache/compiled_kernels/b3/cb3wpx6rjkaergqliask6nznyzrunubqguqa3smp5sn7b4uz6idx.py new file mode 100644 index 0000000000000000000000000000000000000000..58dfb300c8e8bdf9aaeb7e42c3b474ac71db504d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b3/cb3wpx6rjkaergqliask6nznyzrunubqguqa3smp5sn7b4uz6idx.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vp/cvpf44p6d56ladmrqf7fvuche2mzrfe4zaaok6pmgcub4rrhj5rm.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 
'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 
8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/v4/cv43dygvyi45kg6vkof4c4ine37x3os266mathuyrzj2acxe5qrs.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[2, 1, 
16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': 
'*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 
'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due 
to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + 
assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream3) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream3 = get_raw_stream(3) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 2, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = 
runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import 
compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/b3/cb3wvsz364lkwor5ugj65snlotdptbzo235wic5ntoxv6ukogiv6.py b/SpecForge-ext/cache/compiled_kernels/b3/cb3wvsz364lkwor5ugj65snlotdptbzo235wic5ntoxv6ukogiv6.py new file mode 100644 index 0000000000000000000000000000000000000000..26dd7cf4a551707c069a708076789e039367908c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b3/cb3wvsz364lkwor5ugj65snlotdptbzo235wic5ntoxv6ukogiv6.py @@ -0,0 +1,309 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = 
torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5o/c5oswnr7dpwwcqp5m7thpaf4owvpseka6amdvdroewnorzsre6n2.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), 
kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 
+triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ym/cymroqqgat67pxc4rnyt5dzn7thy5u4fbc3jl5cbhuh5rx3hz3bb.py +# Topologically Sorted Source Nodes: 
[squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:6" = PlaceHolder[target=primals_14] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor 
"bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + 
+from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < 
xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12, primals_13, primals_14 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream6 = get_raw_stream(6) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream6) + del primals_12 + ps1 = s24*s25 + buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9 + stream6 = get_raw_stream(6) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, 
primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream6) + del primals_14 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:6', dtype=torch.int64) + primals_9 = 1 + primals_10 = 2 + primals_11 = 32 + primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_13 = 8 + primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/b3/cb3yxtybf744swmcpe2lvz7uxmfgl5a6kt4up2cmxf36y3ryayam.py b/SpecForge-ext/cache/compiled_kernels/b3/cb3yxtybf744swmcpe2lvz7uxmfgl5a6kt4up2cmxf36y3ryayam.py new file mode 100644 index 0000000000000000000000000000000000000000..27b6caf0e829749167f88e538e46ee9647af194a --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/b3/cb3yxtybf744swmcpe2lvz7uxmfgl5a6kt4up2cmxf36y3ryayam.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = 
xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/b3/cb3zn7havdmrrxxamfnztb3sb6m4ob25cezvpphrxbtdmmxc33dr.py b/SpecForge-ext/cache/compiled_kernels/b3/cb3zn7havdmrrxxamfnztb3sb6m4ob25cezvpphrxbtdmmxc33dr.py new file mode 100644 index 0000000000000000000000000000000000000000..61c21495669c2ac9bb9f757d9de25301253284e5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b3/cb3zn7havdmrrxxamfnztb3sb6m4ob25cezvpphrxbtdmmxc33dr.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, 
regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, 
eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/b5/330fb706ef2647f5b047a3b25af14cf31ad187cba1101ee6ad86347cbba08571.best_config b/SpecForge-ext/cache/compiled_kernels/b5/330fb706ef2647f5b047a3b25af14cf31ad187cba1101ee6ad86347cbba08571.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b5/330fb706ef2647f5b047a3b25af14cf31ad187cba1101ee6ad86347cbba08571.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/b5/a174d34de4f5840ed8ab9627d4d36ac8e8b4a39f57f3b50a970ebbad815b55c8.best_config b/SpecForge-ext/cache/compiled_kernels/b5/a174d34de4f5840ed8ab9627d4d36ac8e8b4a39f57f3b50a970ebbad815b55c8.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c2d9b36c5180887fa413aa1eb230c04dc216dd00 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b5/a174d34de4f5840ed8ab9627d4d36ac8e8b4a39f57f3b50a970ebbad815b55c8.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/b5/cb5pa6ru5ujw6hi6je3xi7pbejqaaykzw3kd6ldbytfarnm2bc4c.py b/SpecForge-ext/cache/compiled_kernels/b5/cb5pa6ru5ujw6hi6je3xi7pbejqaaykzw3kd6ldbytfarnm2bc4c.py new file mode 100644 index 0000000000000000000000000000000000000000..1f18ce6b8a8c40508c52f5543e62e4425b17372b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b5/cb5pa6ru5ujw6hi6je3xi7pbejqaaykzw3kd6ldbytfarnm2bc4c.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/b5/cb5x7i45gu46spgpzunuljyjvppeggj22t7uf7nlbn6nvqjroiob.py b/SpecForge-ext/cache/compiled_kernels/b5/cb5x7i45gu46spgpzunuljyjvppeggj22t7uf7nlbn6nvqjroiob.py new file mode 100644 index 0000000000000000000000000000000000000000..45d0a90598a16c7999fb99b884064c7af32a3f2e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b5/cb5x7i45gu46spgpzunuljyjvppeggj22t7uf7nlbn6nvqjroiob.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + 
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/b5/cb5ykkkerkiarsqyvptvtskbx5eis4sgeqpymd7t7y6a2mgsuxni.py b/SpecForge-ext/cache/compiled_kernels/b5/cb5ykkkerkiarsqyvptvtskbx5eis4sgeqpymd7t7y6a2mgsuxni.py new file mode 100644 index 0000000000000000000000000000000000000000..f5837fe017774d202034b6fa3da6668d3334d412 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/b5/cb5ykkkerkiarsqyvptvtskbx5eis4sgeqpymd7t7y6a2mgsuxni.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': 
'*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, 
eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bc/8d8709128bf09bdaa511e12b9eb09259b1c714ede40de5b3b2590a9f47abd324.best_config b/SpecForge-ext/cache/compiled_kernels/bc/8d8709128bf09bdaa511e12b9eb09259b1c714ede40de5b3b2590a9f47abd324.best_config new file mode 100644 index 0000000000000000000000000000000000000000..013261ee4b82e38df2fbf329df3c69013117eca4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bc/8d8709128bf09bdaa511e12b9eb09259b1c714ede40de5b3b2590a9f47abd324.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 57, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bc/cbcjmmdd4rwzphcjmnxlmi32as4rv3cdhd37zaajsactabdidubh.py b/SpecForge-ext/cache/compiled_kernels/bc/cbcjmmdd4rwzphcjmnxlmi32as4rv3cdhd37zaajsactabdidubh.py new 
file mode 100644 index 0000000000000000000000000000000000000000..02bd2a757e27be59be7fde3261a78cca1f94fa13 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bc/cbcjmmdd4rwzphcjmnxlmi32as4rv3cdhd37zaajsactabdidubh.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bh/cbhqle56n7we4b4miasvgh4jqrjbkehmv3legvjui32dka2bilvr.py b/SpecForge-ext/cache/compiled_kernels/bh/cbhqle56n7we4b4miasvgh4jqrjbkehmv3legvjui32dka2bilvr.py new file mode 100644 index 0000000000000000000000000000000000000000..0b509dcd7b37bbf5e817f207469af0100b3660c3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bh/cbhqle56n7we4b4miasvgh4jqrjbkehmv3legvjui32dka2bilvr.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from 
torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 
'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + 
Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulates `dq` over a sequence of K/V tiles and returns it.
    #
    # Every loop iteration passes the current transposed K/V tile pointers to
    # `bwd_dq_block_mn`, replaces `dq` with that call's return value, and then
    # advances the pointers and `offs_n2` by the amount reported by
    # `get_offset_for_next_block`.  The leading `arg_*`, `in_ptr16`,
    # `out_ptr0` and `ks*` parameters are not read here; they are forwarded
    # unchanged to `bwd_dq_block_mn`.

    # Constants annotated tl.constexpr that this function reads.
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    # Number of BLOCK_N2-wide steps inside one sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504

    Q_LEN = 2048
    KV_LEN = ks0

    # The code below only works when BLOCK_M2 is a multiple of BLOCK_N2.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Transposed addressing: the N offsets run along axis 1 and the head-dim
    # offsets along axis 0.
    k_col_term = offs_n2[None, :] * stride_kn
    k_row_term = offs_k[:, None] * stride_kd
    kT_ptrs = K + k_col_term + k_row_term

    v_col_term = offs_n2[None, :] * stride_vn
    v_row_term = offs_v[:, None] * stride_vd
    vT_ptrs = V + v_col_term + v_row_term

    # Trip count: the smaller of (sparse block count * steps per sparse block)
    # and ceil(KV_LEN / BLOCK_N2), where the latter is clamped to at least 1.
    sparse_steps = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
    dense_steps = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    hi = tl.minimum(sparse_steps, dense_steps)

    for step_idx in range(0, hi):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in KV rows) from this tile to the next one to process.
        advance = get_offset_for_next_block(
            step_idx, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        # Move both tile pointers and the matching row offsets together.
        kT_ptrs = kT_ptrs + advance * stride_kn
        vT_ptrs = vT_ptrs + advance * stride_vn
        offs_n2 = offs_n2 + advance

    return dq
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulates `dk` and `dv` over a sequence of Q tiles and returns both.
    #
    # Every loop iteration passes the current Q^T and dO tile pointers to
    # `bwd_dkdv_block_mn`, replaces `dk`/`dv` with that call's return values,
    # and then advances the pointers and `offs_m1` by the amount reported by
    # `get_offset_for_next_block`.  The leading `arg_*`, `in_ptr16`,
    # `out_ptr0` and `ks*` parameters are not read here; they are forwarded
    # unchanged to `bwd_dkdv_block_mn`.

    # Constants annotated tl.constexpr that this function reads.
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    # Number of BLOCK_M1-wide steps inside one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504

    Q_LEN = 2048
    KV_LEN = ks0

    # The code below only works when BLOCK_N1 is a multiple of BLOCK_M1.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (M offsets along axis 1); dO is addressed
    # row-major (M offsets along axis 0).
    q_col_term = offs_m1[None, :] * stride_qm
    q_row_term = offs_k[:, None] * stride_qd
    qT_ptrs = Q + q_col_term + q_row_term

    do_row_term = offs_m1[:, None] * stride_dom
    do_col_term = offs_v[None, :] * stride_dod
    do_ptrs = DO + do_row_term + do_col_term

    # Trip count: the smaller of (sparse block count * steps per sparse block)
    # and ceil(Q_LEN / BLOCK_M1) clamped to at least 1.  The minimum covers
    # runs with a very large sparse block size (i.e. no block mask).
    sparse_steps = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
    dense_steps = tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)
    hi = tl.minimum(sparse_steps, dense_steps)

    for step_idx in range(0, hi):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in Q rows) from this tile to the next one to process.
        advance = get_offset_for_next_block(
            step_idx, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # Move both tile pointers and the matching row offsets together.
        qT_ptrs = qT_ptrs + advance * stride_qm
        do_ptrs = do_ptrs + advance * stride_dom
        offs_m1 = offs_m1 + advance

    return dk, dv
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far to advance after iteration `loop_iter`.
    #
    # With BLOCKS_ARE_CONTIGUOUS set the answer is always BLOCK.  Otherwise
    # the step is BLOCK while still inside the current sparse block, and on
    # the last step of a sparse block it is the distance to the start of the
    # next block listed in `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK

    # Position within `col_indices` of the sparse block being processed.
    sparse_pos = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_pos, eviction_policy="evict_last")
    # The lookahead read is masked off once it would pass `total_blocks`.
    following_block = tl.load(
        col_indices + sparse_pos + 1,
        eviction_policy="evict_last",
        mask=sparse_pos + 1 < total_blocks,
    )

    # True only on the final BLOCK-sized step of the current sparse block.
    at_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Gap between the two sparse blocks, less the steps already taken
    # inside the current one.
    block_gap = (following_block - this_block) * SPARSE_BLOCK
    already_stepped = (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    jump = block_gap - already_stepped

    # Arithmetic select between `jump` and `BLOCK` using the boolean flag.
    return jump * at_block_end + (1 - at_block_end) * BLOCK

@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Returns `indices` unchanged when no `max_len` is given; otherwise
    # returns `indices` reduced modulo `max_len`.
    if max_len is None:
        return indices
    else:
        return indices % max_len
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2-D tile, masking each axis whose IS_DIVISIBLE_* flag is false.
    #
    # Masked axes are bounds-checked against M_LEN (for `offs_m`, along
    # axis 0) and N_LEN (for `offs_n`, along axis 1); masked-out elements are
    # filled with 0.0.  When both flags are true the load is unmasked.
    #
    # If both strides are supplied, the tile addresses are built here from
    # `ptr`, the offsets and the strides; if either stride is None, `ptr` is
    # used exactly as passed in.
    if stride_m is not None and stride_n is not None:
        row_term = offs_m[:, None] * stride_m
        col_term = offs_n[None, :] * stride_n
        ptr = ptr + row_term + col_term

    # Exactly one of the four branches below is taken.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Neither axis needs a bounds check.
            return tl.load(ptr)
        else:
            # Only the N axis needs a bounds check.
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            # Only the M axis needs a bounds check.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            # Both axes need a bounds check.
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
'*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': 
False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = 
arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v 
= tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/bk/cbkz6lgnz7445qug7slfnuiiwthgo7iovyqztbp7hzqg2rjr5o62.py b/SpecForge-ext/cache/compiled_kernels/bk/cbkz6lgnz7445qug7slfnuiiwthgo7iovyqztbp7hzqg2rjr5o62.py new file mode 100644 index 0000000000000000000000000000000000000000..3674619193b8e78f7ebef06aafaeefaa369881da --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bk/cbkz6lgnz7445qug7slfnuiiwthgo7iovyqztbp7hzqg2rjr5o62.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile 
+from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mq/cmq33rjxz47uciqulhibuqonab77aboxlu3alymsqaa5zvrwrl5c.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# 
mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:7"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor 
"b8[2, 1, 2048][2048, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, 2048][2048, 1]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': 
None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + 
tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/s6/cs6iqg54vezijca5pa3j2zarnynctaaf35ewgtya5ewf73uv3olg.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 544 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fh/cfhw7jsr5sk74kchdhqog3qyq7ge3qjqmmqwy37lr2be4lpykzx3.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => 
full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 512, 16, 1]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:7" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:7" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:7" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:7" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_7 : 
Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 
16][16, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %where : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_11 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, 
step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %sum_3 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, 
torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %where_1 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import 
triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + 
tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pe/cpecz443wnneogc65oicauauoytwy7k6ryeyv24laczmux6pdi2b.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:7" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:7" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:7" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[2, 1, 16, 16][272, 272, 17, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : 
Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': 
{'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + 
tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((2, 1, 16, 16), (256, 512, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream7 = get_raw_stream(7) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 512, 16384, stream=stream7) + del arg0_1 + buf15 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_1.run(buf15, 544, stream=stream7) + buf8 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 
1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_1.run(buf8, 544, stream=stream7) + buf6 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 32, 16, stream=stream7) + del buf0 + buf22 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 32, 16, stream=stream7) + del buf8 + buf19 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((2, 1, 
16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream7 = get_raw_stream(7) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 32, 16, stream=stream7) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/bp/02166e66902b07c981abc0d6357a1442342717d52a0a4ffc2c98583aa957ee8d.best_config b/SpecForge-ext/cache/compiled_kernels/bp/02166e66902b07c981abc0d6357a1442342717d52a0a4ffc2c98583aa957ee8d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bp/02166e66902b07c981abc0d6357a1442342717d52a0a4ffc2c98583aa957ee8d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/bp/4ff8e0241eceaf80a01d05271d3ff1bec630c19537a184598d4ec49bd53d786f.best_config b/SpecForge-ext/cache/compiled_kernels/bp/4ff8e0241eceaf80a01d05271d3ff1bec630c19537a184598d4ec49bd53d786f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf4eb5ae8826a07243c88f3ee991df371ea45fb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bp/4ff8e0241eceaf80a01d05271d3ff1bec630c19537a184598d4ec49bd53d786f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/bp/cbp4ofim2oujxe6hm47xzugia67k4kofgbgvt7n7d5gd3iux76li.py b/SpecForge-ext/cache/compiled_kernels/bp/cbp4ofim2oujxe6hm47xzugia67k4kofgbgvt7n7d5gd3iux76li.py new file mode 100644 index 0000000000000000000000000000000000000000..a73b462c7d3b0e3288b213668f3fcb0360459af4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bp/cbp4ofim2oujxe6hm47xzugia67k4kofgbgvt7n7d5gd3iux76li.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = 
tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bp/cbpuekklhjdszdnpjmnzg77zhi5rum3iueweicitfxwda6abrl2a.py b/SpecForge-ext/cache/compiled_kernels/bp/cbpuekklhjdszdnpjmnzg77zhi5rum3iueweicitfxwda6abrl2a.py new file mode 100644 index 0000000000000000000000000000000000000000..e29f51abfce77d159d041e65372b0165234a3f6a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bp/cbpuekklhjdszdnpjmnzg77zhi5rum3iueweicitfxwda6abrl2a.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', 
other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/bs/3ccc058a4644249487c1bf2eda709262ae9176f2ce0d56eb626f4ab685c220b3.best_config b/SpecForge-ext/cache/compiled_kernels/bs/3ccc058a4644249487c1bf2eda709262ae9176f2ce0d56eb626f4ab685c220b3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..88fa1eb0a94466d5a4e4f6b6a3a910edd6bf446b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bs/3ccc058a4644249487c1bf2eda709262ae9176f2ce0d56eb626f4ab685c220b3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py b/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py new file mode 100644 index 0000000000000000000000000000000000000000..9d160f969198a5fa09291a4ff0fb79f0bad26939 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, 
None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/cc/ccc2zhcp2onp6meazov42mk42clahnw3nlyeh7lah2divver6mup.py b/SpecForge-ext/cache/compiled_kernels/cc/ccc2zhcp2onp6meazov42mk42clahnw3nlyeh7lah2divver6mup.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe48b1180c9a22ecbefa49ac097f932f102a390 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cc/ccc2zhcp2onp6meazov42mk42clahnw3nlyeh7lah2divver6mup.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as 
tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * 
XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/cc/ccctqwiv2ol5sk2rtx3jkmxgeoonbpa32j4eyltbnwdla6jsvr6r.py b/SpecForge-ext/cache/compiled_kernels/cc/ccctqwiv2ol5sk2rtx3jkmxgeoonbpa32j4eyltbnwdla6jsvr6r.py new file mode 100644 index 
0000000000000000000000000000000000000000..bce6dba54a66925cdd41fa961151695326980250 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cc/ccctqwiv2ol5sk2rtx3jkmxgeoonbpa32j4eyltbnwdla6jsvr6r.py @@ -0,0 +1,320 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zu/czupkfdsgvzkkkyrmre5slwdxod32ccb5eacvhg2ud5wd2ypvoq2.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, 
aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs 
= {}) +# %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 
1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': 
True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = 
tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ue/cueey4eyo3vezj23w73bspqk3hwft3v6gzwq43jxzs4bcvkqnk3y.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, 
s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 
- ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, 
ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, 
eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = 
tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48 + stream7 = get_raw_stream(7) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream7) + del tangents_2 + buf1 = 
empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream7 = get_raw_stream(7) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream7) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 8 + primals_11 = 32 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:7', dtype=torch.int64) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + 
return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ck/ccksctmn7dvdwu27cnkgisk777nnrd2keej2xmbgrfhagxbfw7z5.py b/SpecForge-ext/cache/compiled_kernels/ck/ccksctmn7dvdwu27cnkgisk777nnrd2keej2xmbgrfhagxbfw7z5.py new file mode 100644 index 0000000000000000000000000000000000000000..d669ce5abea7989363e19cd6680abc26b6da8d66 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ck/ccksctmn7dvdwu27cnkgisk777nnrd2keej2xmbgrfhagxbfw7z5.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], 
(10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 
'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. 
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query block. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query block. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each key/value block. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each key/value block. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query block. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query block. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each key/value block. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each key/value block. + + # The options below are kernel options that can be applied for certain score_mods, + # or that involve a numerics vs. perf tradeoff. + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but is slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
# Accumulates dK and dV for one tile of key/value rows by walking over the
# query blocks selected for it (q_indices / sparse_q_num_blocks) and calling
# bwd_dkdv_block_mn once per BLOCK_M1-sized step.
#
# Inputs visible from the code:
#   Q, DO, DELTA, LSE      base pointers indexed with offs_m1
#   dk, dv                 running accumulators; updated and returned
#   k, v                   tiles forwarded unchanged to bwd_dkdv_block_mn
#   offs_n1 / offs_m1      offsets along the KV / Q dimensions
#   stride_*               element strides used to build the Q and DO pointers
#   q_indices, sparse_q_num_blocks
#                          block-sparse metadata consumed by
#                          get_offset_for_next_block
# The arg_* / in_ptr16 / out_ptr0 parameters are only passed through.
# Returns the updated (dk, dv).
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Specialisation constants emitted into every helper of this generated
    # kernel; not all of them are referenced in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    # Numerically 1 / sqrt(128), i.e. 1 / sqrt(QK_HEAD_DIM).
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32

    # Number of BLOCK_M1 steps that make up one sparse Q block (128 // 64 = 2).
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    # 1 / ln(2); forwarded to bwd_dkdv_block_mn, which rescales the scores
    # with it before calling exp2.
    RCP_LN2: tl.constexpr = 1.44269504
    # Sequence lengths are hard-coded for this specialisation.
    Q_LEN = 2048
    KV_LEN = 2048

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed: head-dim offsets run along rows and query
    # offsets along columns. DO keeps query offsets on rows.
    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # The minimum is needed to handle the case where we run with a super large
    # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    for start_m in range(0, hi):
        # One BLOCK_M1-wide slice of queries contributes to dk and dv.
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
            dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )
        # Increment pointers.
        # `offset` is the distance, in query rows, to the next slice; it can
        # jump across gaps between the sparse blocks listed in q_indices.
        offset = get_offset_for_next_block(
            start_m, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # Pointers and the logical row offsets advance together so that the
        # masks computed from offs_m1 stay aligned with the data being loaded.
        qT_ptrs += offset * stride_qm
        do_ptrs += offset * stride_dom
        offs_m1 += offset

    return dk, dv
# Processes a single (KV tile, Q slice) pair of the backward pass and adds its
# contribution to the dk / dv accumulators.
#
# Steps, as written below:
#   1. load the transposed Q tile and the per-row LSE values
#   2. recompute the scaled scores qkT = k @ qT
#   3. for partially masked blocks (IS_FULL_BLOCKS false) evaluate the inlined
#      mask and set masked scores to -inf
#   4. pT = exp2(scores * RCP_LN2 - lse)
#   5. dv += pT @ do
#   6. dsT = pT * (v @ do^T - Di), masked again, then dk += dsT @ qT^T
# Returns the updated (dk, dv).
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Specialisation constants emitted into every helper of this generated
    # kernel; not all of them are referenced in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB reversed order since Q is transposed
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    # With IS_DIVISIBLE = True above, only the unmasked load is taken.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf entries by 0.0; presumably this keeps the subtraction in
    # exp2 below from producing (-inf) - (-inf) = NaN -- TODO confirm.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # m varies along columns, n along rows, matching the transposed qkT tile.
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    # The score modification inlined here is the identity: the scores pass
    # through unchanged.
    pre_mod_scores = qkT
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask. With L = value loaded from in_ptr16 at off_z, an
        # (m, n) pair is kept when
        #     (m >= n and n < L and m < L)
        #  or (n >= 2048 and (n mod 2048) < L and (n - m) mod 2048 == 0)
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        # tmp56..tmp64: n mod 2048 with the sign of the divisor -- the
        # divisor is added back when the remainder is non-zero and its sign
        # differs from the divisor's.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # tmp68..tmp75: (n - m) mod 2048, using the same sign correction.
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Rescale so that exp2 can be used below.
        post_mod_scores *= RCP_LN2
    # NOTE(review): this subtracts lse without rescaling it, so LSE is
    # presumably already stored in base-2 units -- confirm against the forward kernel.
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The joint modification inlined here is also the identity.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is True above, so this branch is not taken in this kernel; the
    # names it defines are not used later in the function.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
# Utility triton funcs

# Returns how far the caller must advance its pointers/offsets to reach the
# next BLOCK-sized step of a block-sparse iteration.
#   loop_iter              index of the step that was just processed
#   col_indices            pointer to the list of sparse block indices
#   total_blocks           number of valid entries in col_indices
#   SPARSE_BLOCK           size of one sparse block
#   SPARSE_BLOCK_MULTIPLE  number of BLOCK steps per sparse block
#   BLOCK                  size of one step
#   BLOCKS_ARE_CONTIGUOUS  when true every step is simply BLOCK
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    # The load of the following entry is masked off once it would run past
    # total_blocks.
    # NOTE(review): no `other=` is supplied, so next_block is unspecified in
    # that case; presumably the caller's loop has ended by then -- confirm.
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    # True on the last step inside the current sparse block.
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance from the start of the current sparse block to the start of the
    # next listed one, minus the (SPARSE_BLOCK_MULTIPLE - 1) steps taken so far
    # inside the current block.
    jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Branch-free select between the jump and a plain BLOCK step.
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset

# Wraps `indices` into [0, max_len) with a modulo when max_len is given;
# returns `indices` unchanged when max_len is None.
@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

# Loads through a block pointer, enabling boundary checks only on the
# dimensions that need them: dimension 0 when IS_DIVISIBLE is false,
# dimension 1 when SAFE_HEAD_DIM is false. Checked loads pad with zeros.
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

# Loads a 2D tile through plain pointers, masking only the dimensions that are
# flagged as not divisible.
#   ptr                 base pointer, or an already expanded 2D pointer tile
#                       when stride_m / stride_n are None
#   offs_m, offs_n      1D offsets for rows / columns
#   stride_m, stride_n  when both are given, the 2D pointers are built here
#   IS_DIVISIBLE_M/N    when false, that dimension is masked against
#                       M_LEN / N_LEN
# Masked-out elements are returned as 0.0.
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Calculate final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wl/cwlepqkeid7zo46auilgnirm5hqxuf7wqtbi3bhndddz2uyg7dbu.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s3, 32000][32000*s3, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, 
TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + 
r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg3_1 : Tensor "f32[8, s3, 32000][s71, 32000, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg3_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + 
tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/d2/cd2itktqorx3l5od6vqzz3m73qqojgnnlgzl47ehhpxtuclki64i.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:7" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:7" = PlaceHolder[target=argmax_1] +# %arg4_1 : Tensor "i64[8, s3, 1][s3, 1, 1]cuda:7" = PlaceHolder[target=arg4_1] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': 
'*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 
4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6h/c6hb7ee3npf6er65o5pppoq7uu3izd2oii7lz3cy4tsiysyvjtwd.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg6_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7" = PlaceHolder[target=arg6_1] +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6s/c6sofobf6aibktijbyii34pvwp2u36pgbnmz5t2jtcp2eu7hxsct.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to 
ATen node mapping: +# clamp_min => clamp_min +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_2] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg4_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg6_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 
class Runner:
    """Host-side driver that launches the compiled kernels of this graph.

    ``partitions`` is a list of callables that can be rewritten in place with
    ``recursively_apply_fns``; ``call`` executes the graph itself.
    """

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition with ``fn(partition)``, pairing by position.

        Pairing stops at the shorter of ``fns`` and ``self.partitions``.
        """
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]

    def call(self, args):
        """Run the graph on ``args`` and return a 1-tuple with the result.

        ``args`` is a list of seven entries (three symbolic sizes interleaved
        with four tensors). It is emptied so that the caller's list does not
        keep the input tensors alive.
        """
        s3, arg1_1, s71, arg3_1, arg4_1, s14, arg6_1 = args
        args.clear()

        # Validate the layouts the kernels were compiled for.
        assert_size_stride(arg1_1, (8, s3, 32000), (32000*s3, 32000, 1))
        assert_size_stride(arg3_1, (8, s3, 32000), (s71, 32000, 1))
        assert_size_stride(arg4_1, (8, s3, 1), (s3, 1, 1))
        assert_size_stride(arg6_1, (8, s14, 1), (s14, 1, 1))

        with torch.cuda._DeviceGuard(7):
            torch.cuda.set_device(7)

            # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax]
            buf0 = empty_strided_cuda((8, s3), (s3, 1), torch.int64)
            triton_red_fused_argmax_0.run(
                arg1_1, buf0, 8*s3, 32000, stream=get_raw_stream(7)
            )
            del arg1_1

            # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax]
            buf1 = empty_strided_cuda((8, s3), (s3, 1), torch.int64)
            triton_red_fused_argmax_1.run(
                arg3_1, buf1, s3, s71, 8*s3, 32000, stream=get_raw_stream(7)
            )
            del arg3_1

            # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum]
            buf3 = empty_strided_cuda((2, ), (1, ), torch.int64)
            triton_red_fused_eq_mul_squeeze_sum_2.run(
                buf0, buf1, arg4_1, buf3, s3, 2, 4*s3, stream=get_raw_stream(7)
            )
            del arg4_1, buf0, buf1

            # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum]
            buf5 = empty_strided_cuda((2, ), (1, ), torch.int64)
            triton_red_fused_sum_3.run(
                arg6_1, buf5, s14, 2, 4*s14, stream=get_raw_stream(7)
            )
            del arg6_1

            # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div]
            buf7 = empty_strided_cuda((), (), torch.float32)
            triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(
                buf3, buf5, buf7, 1, 2, stream=get_raw_stream(7)
            )
            del buf3, buf5
        return (buf7, )


runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on randomly filled inputs of the recorded example shapes."""
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def make_input(size, stride, dtype):
        return rand_strided(size, stride, device='cuda:7', dtype=dtype)

    arg0_1 = 2025
    arg1_1 = make_input((8, 2025, 32000), (64800000, 32000, 1), torch.bfloat16)
    arg2_1 = 65024000
    arg3_1 = make_input((8, 2025, 32000), (65024000, 32000, 1), torch.float32)
    arg4_1 = make_input((8, 2025, 1), (2025, 1, 1), torch.int64)
    arg5_1 = 2025
    arg6_1 = make_input((8, 2025, 1), (2025, 1, 1), torch.int64)

    def fn():
        # A fresh list per run: ``call`` clears the list it receives.
        return call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1])

    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
a/SpecForge-ext/cache/compiled_kernels/cn/ccn3lbcmvoojl7h6shfwmw3dug5zsxjbghxgap2am4mubpm3o5s5.py b/SpecForge-ext/cache/compiled_kernels/cn/ccn3lbcmvoojl7h6shfwmw3dug5zsxjbghxgap2am4mubpm3o5s5.py new file mode 100644 index 0000000000000000000000000000000000000000..27b89f2c9d4d6d6ef3a0145ec975b5af578d4065 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cn/ccn3lbcmvoojl7h6shfwmw3dug5zsxjbghxgap2am4mubpm3o5s5.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = 
torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bw/cbwh3lvhwrwzn37scdwidtv64cijchzfdtklw5kp77i2o4tue5wk.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import 
triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 
'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, 
primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream0) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call 
+recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/cn/ccnsjveu43nfa7cb7ww4aiu5yps2rzsvlmzvexepbb6yezpu7lan.py b/SpecForge-ext/cache/compiled_kernels/cn/ccnsjveu43nfa7cb7ww4aiu5yps2rzsvlmzvexepbb6yezpu7lan.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f630a182f2ca4f06248bb0a7271707aae9cd29 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/cn/ccnsjveu43nfa7cb7ww4aiu5yps2rzsvlmzvexepbb6yezpu7lan.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': 
True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) 
+ tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/d2/cd2itktqorx3l5od6vqzz3m73qqojgnnlgzl47ehhpxtuclki64i.py b/SpecForge-ext/cache/compiled_kernels/d2/cd2itktqorx3l5od6vqzz3m73qqojgnnlgzl47ehhpxtuclki64i.py new file mode 100644 index 0000000000000000000000000000000000000000..74e21bb82e625ac5f6beb8d64b34cc631228258d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/d2/cd2itktqorx3l5od6vqzz3m73qqojgnnlgzl47ehhpxtuclki64i.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': 
True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/d5/cd5lwybnbjetoaf6hxajj7itqmrk3fj4xejz52d5s2w56qouijor.py b/SpecForge-ext/cache/compiled_kernels/d5/cd5lwybnbjetoaf6hxajj7itqmrk3fj4xejz52d5s2w56qouijor.py new file mode 100644 index 0000000000000000000000000000000000000000..333450195a7f3d8aa69ad8eb5e8f81fcbb4198a3 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/d5/cd5lwybnbjetoaf6hxajj7itqmrk3fj4xejz52d5s2w56qouijor.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + 
(tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/d7/cd73nnhe5667wrvnpwwmwkmeyjl5pwluzxnmki4j25npjpqqqb24.py b/SpecForge-ext/cache/compiled_kernels/d7/cd73nnhe5667wrvnpwwmwkmeyjl5pwluzxnmki4j25npjpqqqb24.py new file mode 100644 index 0000000000000000000000000000000000000000..4280d454a8eaa8188dfdec1887eef6a78513c436 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/d7/cd73nnhe5667wrvnpwwmwkmeyjl5pwluzxnmki4j25npjpqqqb24.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/db/8c5cba917943d72e6095940378f428aed2c2e397c029a543e6c09a7f3dfaaa5d.best_config b/SpecForge-ext/cache/compiled_kernels/db/8c5cba917943d72e6095940378f428aed2c2e397c029a543e6c09a7f3dfaaa5d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7921a12b007ca46a00e959ad115401adf0bd4471 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/db/8c5cba917943d72e6095940378f428aed2c2e397c029a543e6c09a7f3dfaaa5d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": 
"1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "B46RWD5PEMKEQR7EBR6IG3BGTK4P7CWBVNOODNZQX5NAVXXVIH2A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/db/cdb2xzrfoo57wjb66k65l77dul6dh4uspo7o2timi7dmrmwwdzhq.py b/SpecForge-ext/cache/compiled_kernels/db/cdb2xzrfoo57wjb66k65l77dul6dh4uspo7o2timi7dmrmwwdzhq.py new file mode 100644 index 0000000000000000000000000000000000000000..5a2d61fa0f9b910aea9c200e8c929626c338eeb6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/db/cdb2xzrfoo57wjb66k65l77dul6dh4uspo7o2timi7dmrmwwdzhq.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = 
torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/m7/cm7u3olama3gox426hxhxixqvzhslez5o7pvi4bnehh2g4ww6k6i.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, 
DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, 
R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/42/c42olsblh7ymaib2tr5gwhfzuighing5bkpmabq5hx7nxumtbsig.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:1" 
= PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:1" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:1" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:1" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[8][1]cuda:1" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, 
+triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 
'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = 
arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - 
NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + 
assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream1) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream1 = get_raw_stream(1) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del 
primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:1', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:1', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:1', dtype=torch.int32) + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:1', 
dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/db/cdbdssx4zf2lbltfbsjfmqxymtxotzkkq4kkmqp7jet2odl5uhc6.py b/SpecForge-ext/cache/compiled_kernels/db/cdbdssx4zf2lbltfbsjfmqxymtxotzkkq4kkmqp7jet2odl5uhc6.py new file mode 100644 index 0000000000000000000000000000000000000000..009b45877aa6f2ab107d6a78820a2e58544fdc90 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/db/cdbdssx4zf2lbltfbsjfmqxymtxotzkkq4kkmqp7jet2odl5uhc6.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 
'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/db/cdbndait3n3m2xjscmwgxwrzumtrzuidy4zn3mvnp5xydagqf7fw.py b/SpecForge-ext/cache/compiled_kernels/db/cdbndait3n3m2xjscmwgxwrzumtrzuidy4zn3mvnp5xydagqf7fw.py new file mode 100644 index 0000000000000000000000000000000000000000..2934fb5d385b73f5159444bb5bbd4b338289c414 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/db/cdbndait3n3m2xjscmwgxwrzumtrzuidy4zn3mvnp5xydagqf7fw.py @@ -0,0 +1,184 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as 
get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py +# Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# mul_1 => mul_23 +# pow_1 => pow_1 +# rsqrt => rsqrt +# to_1 => convert_element_type_1 +# variance => mean +# Graph fragment: +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7" = PlaceHolder[target=primals_4] +# %buf0 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:7" = PlaceHolder[target=buf0] +# %primals_5 : Tensor "f64[][]cpu" = PlaceHolder[target=primals_5] +# %primals_7 : Tensor "bf16[s33][1]cuda:7" = PlaceHolder[target=primals_7] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %pow_1 : 
Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {}) +# %mean : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {}) +# %convert_element_type_default_1 : Tensor "f32[][]cpu"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_5, torch.float32), kwargs = {}) +# %add_tensor : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, %convert_element_type_default_1), kwargs = {}) +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_tensor,), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_23 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_7, %convert_element_type_1), kwargs = {}) +# return %buf0,%rsqrt,%mul_23 +triton_red_fused__to_copy_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), 
r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_5, (), ()) + assert_size_stride(primals_7, (s33, ), (1, )) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((s47, s87, 1), (s87, 1, s47*s87), torch.float32) + buf1 = reinterpret_tensor(buf0, (s47, s87, 1), (s87, 1, 1), 0); del buf0 # reuse + buf2 = 
empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel = s47*s87 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0.run(buf1, primals_4, primals_5.item(), primals_7, buf2, s33, triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel, s33, stream=stream7) + del primals_5 + return (buf2, primals_4, primals_7, buf1, s47, s87, s33, s82, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = rand_strided((), (), device='cpu', dtype=torch.float64) + primals_6 = 840433664 + primals_7 = rand_strided((4096, ), (1, ), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/e5/ce5roatsezodqf5qpd32hkspt4fcfiqo7jobom7pzhsftuk27gji.py b/SpecForge-ext/cache/compiled_kernels/e5/ce5roatsezodqf5qpd32hkspt4fcfiqo7jobom7pzhsftuk27gji.py new file mode 100644 index 0000000000000000000000000000000000000000..72846a1fbef2b6aa270058a0022e4972bd73ca6b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/e5/ce5roatsezodqf5qpd32hkspt4fcfiqo7jobom7pzhsftuk27gji.py @@ -0,0 +1,48 @@ + +import 
triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, 
R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ej/cej7q625eupa7qxmjwyzq6vqbzktysdozpsvtlaa2uqslxq4wfvp.py b/SpecForge-ext/cache/compiled_kernels/ej/cej7q625eupa7qxmjwyzq6vqbzktysdozpsvtlaa2uqslxq4wfvp.py new file mode 100644 index 0000000000000000000000000000000000000000..124c86de8cb5680f5c492ab0251d9de255bec7e1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ej/cej7q625eupa7qxmjwyzq6vqbzktysdozpsvtlaa2uqslxq4wfvp.py @@ -0,0 +1,410 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl 
+from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xo/cxogqdjmha6a2mjw43q2sd56tohi4ncj3zpekedkws2baol4oehq.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = 
r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qf/cqfc3szpcuoolzwoo3v5leduvtw3otipmc7qo5zxlbkaehkx5nuo.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, 
warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + 
_tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mj/cmj7cw6kmumq5dhcjj5z77edfzandxlaftganyllgmz7p3fqeubv.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:0" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:0" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:0" = PlaceHolder[target=arg2_1] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + 
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 
8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/q7/cq7qlfw752z7qztzmvc5idn6ikz7c7qhxrmppou7flmcrl63nd7h.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): 
[['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/az/cazc4elakae7tgyuygha6gaxmfo4ouj4mtb6kxylbj7524jvkqaz.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, 
aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:0" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + 
+@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = 
tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (8, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream0 = get_raw_stream(0) + triton_red_fused_argmax_0.run(arg0_1, buf0, 16384, 32000, stream=stream0) + del arg0_1 + buf1 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream0 = get_raw_stream(0) + triton_red_fused_argmax_1.run(arg1_1, buf1, 16384, 32000, stream=stream0) + del arg1_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + stream0 = get_raw_stream(0) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, buf3, 2, 8192, stream=stream0) + del arg2_1 + del buf0 + del buf1 + buf5 = 
empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + stream0 = get_raw_stream(0) + triton_red_fused_sum_3.run(arg3_1, buf5, 2, 8192, stream=stream0) + del arg3_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream0 = get_raw_stream(0) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream0) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + arg1_1 = rand_strided((8, 2048, 32000), (65760000, 32000, 1), device='cuda:0', dtype=torch.float32) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:0', dtype=torch.int64) + arg3_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/em/cema2d67vwnn3vhq5dmcrrnexptf3avprzdd2qc4la2ulmxodzzu.py b/SpecForge-ext/cache/compiled_kernels/em/cema2d67vwnn3vhq5dmcrrnexptf3avprzdd2qc4la2ulmxodzzu.py new file mode 100644 index 0000000000000000000000000000000000000000..fe8c31fa228acb9ce8ffedae88d9560e9f945271 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/em/cema2d67vwnn3vhq5dmcrrnexptf3avprzdd2qc4la2ulmxodzzu.py @@ -0,0 +1,410 @@ 
+# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7p/c7ph4dk7ghsg37h7a46klnkhb6rck4rpgxyqg7fjyewxnxqk5vvs.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:7" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 
16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/lo/clokd3a6z24hxv5xnqanomelba6camwbo47pze4xlztxiuax6uqh.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math 
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), 
tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5d/c5dhltakqikiolalkhxr6qk33ibnkblchixuky4qhu3ngcdubsgg.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:7" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:7" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:7" = PlaceHolder[target=arg2_1] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = 
(%mul,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, 
XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, 
ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 
= tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6s/c6sofobf6aibktijbyii34pvwp2u36pgbnmz5t2jtcp2eu7hxsct.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, 
out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (8, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream7 = get_raw_stream(7) + triton_red_fused_argmax_0.run(arg0_1, buf0, 16384, 32000, stream=stream7) + del arg0_1 + buf1 = empty_strided_cuda((8, 2048), 
(2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream7 = get_raw_stream(7) + triton_red_fused_argmax_1.run(arg1_1, buf1, 16384, 32000, stream=stream7) + del arg1_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + stream7 = get_raw_stream(7) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, buf3, 2, 8192, stream=stream7) + del arg2_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + stream7 = get_raw_stream(7) + triton_red_fused_sum_3.run(arg3_1, buf5, 2, 8192, stream=stream7) + del arg3_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream7 = get_raw_stream(7) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream7) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:7', dtype=torch.bfloat16) + arg1_1 = rand_strided((8, 2048, 32000), (65760000, 32000, 1), device='cuda:7', dtype=torch.float32) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.int64) + arg3_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + 
from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/eu/cb8b28aa80dee4e265bf38d34b96174bb2cfcc005f4ec630c4d7aa4cb0327c66.best_config b/SpecForge-ext/cache/compiled_kernels/eu/cb8b28aa80dee4e265bf38d34b96174bb2cfcc005f4ec630c4d7aa4cb0327c66.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6d7e6fca6219b05b2bbb28ccc5c57a35a1dbc088 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eu/cb8b28aa80dee4e265bf38d34b96174bb2cfcc005f4ec630c4d7aa4cb0327c66.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 214, "triton_cache_hash": "VBVRCEQLKQI4X4GYXD4JC6UEYZT2F7LIKNA2UR4GNVIWAPM6GKFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/eu/ceu3thmw4b2sxynuibvt3bgyszc3xjkaelupjne2jlinm3wf3nzj.py b/SpecForge-ext/cache/compiled_kernels/eu/ceu3thmw4b2sxynuibvt3bgyszc3xjkaelupjne2jlinm3wf3nzj.py new file mode 100644 index 0000000000000000000000000000000000000000..ec785f25cf5ff437c2d1ed16da200260cbaf020c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eu/ceu3thmw4b2sxynuibvt3bgyszc3xjkaelupjne2jlinm3wf3nzj.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from 
torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ka/ckakj5oalaoxcoqpyrpbjs75fxr6p2sp5eymgmxpomaulof2nifw.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:7" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor 
"bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 
%floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 
= tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zj/czjkm7k57yqvjckfd6v5afshsyjeafn25srmkljdz6wbhdiqqu7l.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:7" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 
1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), 
kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 
'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + 
tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream7 = get_raw_stream(7) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, 
primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream7) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream7 = get_raw_stream(7) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream7) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:7', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: 
call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/eu/ceuhopmcdleig6m43h7kk4fhghkl5w2umfjuyngxydc4pr3zpumg.py b/SpecForge-ext/cache/compiled_kernels/eu/ceuhopmcdleig6m43h7kk4fhghkl5w2umfjuyngxydc4pr3zpumg.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cfa0cbf94983b2c854e5cdfb1159a1528c45c3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eu/ceuhopmcdleig6m43h7kk4fhghkl5w2umfjuyngxydc4pr3zpumg.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/eu/ceui6qrb2t3lmzs3ljrqtcomt4b2q6svzo24j6mmryaiovr6kp7y.py b/SpecForge-ext/cache/compiled_kernels/eu/ceui6qrb2t3lmzs3ljrqtcomt4b2q6svzo24j6mmryaiovr6kp7y.py new file mode 100644 index 0000000000000000000000000000000000000000..09d1f43dc101b9139bec000b71eeb904635032fa --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/eu/ceui6qrb2t3lmzs3ljrqtcomt4b2q6svzo24j6mmryaiovr6kp7y.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from 
torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + 
_tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ez/ceztpk557ohveef4x2fpr7mhqfmutptjkv42bvrokmsifscetbsn.py b/SpecForge-ext/cache/compiled_kernels/ez/ceztpk557ohveef4x2fpr7mhqfmutptjkv42bvrokmsifscetbsn.py new file mode 100644 index 0000000000000000000000000000000000000000..5bcf60239e28efd6d8f0e7f3a9071159b1b8d6e5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ez/ceztpk557ohveef4x2fpr7mhqfmutptjkv42bvrokmsifscetbsn.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten 
+inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/e6/ce65awkeaxcxjqfa27pcogsy3sjyxwzxjt3w2rte76m7izgybp2s.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# 
suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:0"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:0"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, 
[%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, 2048][2048, 2048, 
1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 
# Reviewer summary of the kernel source registered below (derived from the
# kernel body; the triple-quoted text itself is runtime source and is kept
# unchanged):
# - xnumel is fixed to 512 outputs and r0_numel to 16384 reduced elements per
#   output.
# - xindex is split into x0 = xindex % 16, x1 = (xindex // 16) % 16 and
#   x2 = xindex // 256; one value tmp3 is loaded from in_ptr0[x2].
# - For every reduction element r: row = r // 128 + 128 * x1 and
#   col = r % 128 + 128 * x0. The element contributes 1 to the sum when
#       (row >= col) & (col < tmp3) & (row < tmp3)
#   or when
#       (col >= 2048) & (col < tmp3) & ((col - row) floor-mod 2048 == 0).
#   NOTE(review): with x0 <= 15 and r % 128 <= 127, col never exceeds 2047, so
#   the second term looks like it cannot fire for in-range indices — confirm
#   against the traced model before relying on it.
# - The int64 count is stored to out_ptr0[xindex] under xmask.
triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 512, 'r0_': 16384},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}}
)
@triton.jit
def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 512
    r0_numel = 16384
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x1 = ((xindex // 16) % 16)
    x0 = (xindex % 16)
    x2 = xindex // 256
    tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
    _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    x6 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_4 = r0_index // 128
        r0_3 = (r0_index % 128)
        tmp0 = r0_4 + 128*x1
        tmp1 = r0_3 + 128*x0
        tmp2 = tmp0 >= tmp1
        tmp4 = tmp1 < tmp3
        tmp5 = tmp0 < tmp3
        tmp6 = tmp4 & tmp5
        tmp7 = tmp2 & tmp6
        tmp8 = tl.full([1, 1], False, tl.int1)
        tmp9 = tmp8 | tmp7
        tmp10 = tl.full([1, 1], 2048, tl.int64)
        tmp11 = tmp1 >= tmp10
        tmp12 = tmp11 & tmp4
        tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
        tmp14 = (tmp13 % tmp10)
        tmp15 = tl.full([1, 1], 0, tl.int32)
        tmp16 = tmp14 != tmp15
        tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
        tmp19 = tmp17 != tmp18
        tmp20 = tmp16 & tmp19
        tmp21 = tmp14 + tmp10
        tmp22 = tl.where(tmp20, tmp21, tmp14)
        tmp23 = tl.full([1, 1], 0, tl.int64)
        tmp24 = tmp22 == tmp23
        tmp25 = tmp12 & tmp24
        tmp26 = tmp9 | tmp25
        tmp27 = tmp26.to(tl.int64)
        tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
        tmp30 = _tmp29 + tmp28
        _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
    tmp29 = tl.sum(_tmp29, 1)[:, None]
    tl.store(out_ptr0 + (x6), tmp29, xmask)
''', device_str='cuda')


# kernel path:
# /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4n/c4ntlraqki6522y3kmq7crnap6gq5asdu5huu7r2d7hvfkgash6w.py
# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
# Source node to ATen node mapping:
#   dense_mask_4 => full_default_4
# Graph fragment:
#   %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   return %index_put_1
# Reviewer summary of the kernel source registered below (derived from the
# kernel body; the triple-quoted text is runtime source and is kept unchanged):
# a pointwise fill that stores int32 zero to out_ptr0[xindex] for every
# xindex < 544. xnumel is hard-coded to 544, which equals the 2 * 1 * 16 * 17
# elements of the full_default_4 tensor listed in the graph fragment.
triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 1024},
    filename=__file__,
    triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 544
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.full([1], 0, tl.int32)
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')


# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py
# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
# Source node to ATen node mapping:
#   arange_4 => iota_4
#   arange_6 => iota_8
#   child_3 => convert_element_type_3
#   child_4 => convert_element_type_4
#   child_7 => convert_element_type_6
#   child_8 => convert_element_type_7
#   col_indices => sort
#   col_indices_1 => sort_1
#   col_range => iota_5
#   col_range_1 => iota_9
#   dense_mask => convert_element_type_2
#   dense_mask_1 => convert_element_type_5
#   dense_mask_2 => full_default_1
#   dense_mask_4 => full_default_4
#   full_blocks => eq_1
#   full_blocks_1 => convert_element_type_1
#   gt => gt
#   index_mask => lt_4
#   index_mask_1 => lt_5
#   lt_3 => lt_3
#   num_blocks_in_row => sum_2
#   num_blocks_in_row_1 => sum_3
#   partial_blocks => bitwise_and_4
#
partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 512, 16, 1]cuda:0" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:0" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:0" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:0" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:0" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 
1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %where : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %iota_11 : Tensor "i64[2][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# 
%unsqueeze_12 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:0, requires_grad: False}) +# %sum_3 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : 
# Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {})
#   %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {})
#   %full_default_5 : Tensor "i32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %where_1 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {})
#   %full_default_6 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {})
#   return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16
# Reviewer summary of the kernel source registered below (derived from the
# kernel body; the triple-quoted text is runtime source and is kept unchanged):
# - Persistent reduction over xnumel = 32 rows of r0_numel = 16 int64 values
#   loaded from in_ptr0[r + 16 * x].
# - Two 0/1 flags are derived per value v: "partial" = (v > 0) & (v < 16384)
#   and "full" = (v == 16384).
# - Each flag row is sorted (stable, descending) together with its column
#   index 0..15; the sorted indices are written as int32 to out_ptr6 (partial)
#   and out_ptr8 (full) at [r + 16 * x].
# - The per-row flag sums are written as int32 to out_ptr4 (partial) and
#   out_ptr5 (full) at [x].
# - Scatter step: for position r in a row, the target column is the sorted
#   index when r < row sum, otherwise 16; the value 1 is stored to
#   out_ptr7 / out_ptr9 at [column + 17 * x]. Both pointers are listed in
#   'mutated_arg_names', i.e. they are updated in place. The device_assert
#   calls check 0 <= column < 17 for unmasked rows.
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 16},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 32
    r0_numel = 16
    R0_BLOCK: tl.constexpr = 16
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 > tmp1
    tmp3 = tl.full([1, 1], 16384, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = tmp5.to(tl.int8)
    tmp7 = tmp6.to(tl.int32)
    tmp8 = r0_1
    tmp9 = tmp8.to(tl.int16)
    tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
    tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
    tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
    tmp14 = tmp0 == tmp3
    tmp15 = tmp14.to(tl.int8)
    tmp16 = tmp15.to(tl.int32)
    tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
    tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
    tmp20 = tmp7.to(tl.int64)
    tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
    tmp23 = tl.where(xmask, tmp21, 0)
    tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
    tmp25 = tmp16.to(tl.int64)
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
    tmp28 = tl.where(xmask, tmp26, 0)
    tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
    tmp30 = tmp24.to(tl.int32)
    tmp31 = tmp29.to(tl.int32)
    tmp32 = tmp13.to(tl.int64)
    tmp33 = tmp32.to(tl.int32)
    tmp34 = tmp8 < tmp30
    tmp35 = tl.full([1, 1], 16, tl.int32)
    tmp36 = tl.where(tmp34, tmp33, tmp35)
    tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
    tmp38 = tmp36 + tmp37
    tmp39 = tmp36 < 0
    tmp40 = tl.where(tmp39, tmp38, tmp36)
    tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
    tmp42 = tl.full([1, 1], 1, tl.int32)
    tmp43 = tmp19.to(tl.int64)
    tmp44 = tmp43.to(tl.int32)
    tmp45 = tmp8 < tmp31
    tmp46 = tl.where(tmp45, tmp44, tmp35)
    tmp47 = tmp46 + tmp37
    tmp48 = tmp46 < 0
    tmp49 = tl.where(tmp48, tmp47, tmp46)
    tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
    tl.store(out_ptr4 + (x0), tmp30, xmask)
    tl.store(out_ptr5 + (x0), tmp31, xmask)
    tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
    tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
    tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
    tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
''', device_str='cuda')


# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qo/cqojmb5e4b5iomuis3bstfp3rn23xoq2xegyr72zgkrjnzktbusv.py
# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
# Source node to ATen node mapping:
#   batched_outputs_3 => clone_4, slice_2
#   col_indices_2 => sort_2
#   num_blocks_in_row_2 => sum_4
#   q_indices => clone_6, convert_element_type_9
#   q_num_blocks => convert_element_type_8
#   transpose => permute_1
# Graph fragment:
#   %buf9 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:0" = PlaceHolder[target=buf9]
#   %buf11 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:0" = PlaceHolder[target=buf11]
#   %sum_4 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:0" = PlaceHolder[target=sum_4]
#   %slice_2 : Tensor "i32[2, 1, 16, 16][272, 272, 17, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {})
#   %clone_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
#   %permute_1 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:0"[num_users=2] =
# call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
#   %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
#   %convert_element_type_9 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
#   %clone_6 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
#   %sum_4 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
#   %convert_element_type_8 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
#   return %buf11,%sum_4,%clone_6,%convert_element_type_8
# Reviewer summary of the kernel source registered below (derived from the
# kernel body; the triple-quoted text is runtime source and is kept unchanged):
# - Persistent reduction over xnumel = 32 rows of r0_numel = 16 int32 values.
# - Row x3 is split into x0 = x3 % 16 and x1 = x3 // 16; element r is read
#   from in_ptr0[x0 + 17 * r + 272 * x1], so the reduction index advances
#   with a stride of 17 elements through the input.
# - The 16 values are sorted (stable, descending) together with their index
#   0..15; the sorted indices are stored as int32 to out_ptr2[r + 16 * x3].
# - The row sum is stored as int32 to out_ptr3[x3].
triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 16},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}}
)
@triton.jit
def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 32
    r0_numel = 16
    R0_BLOCK: tl.constexpr = 16
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    x0 = (xindex % 16)
    x1 = xindex // 16
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
    tmp1 = r0_2
    tmp2 = tmp1.to(tl.int16)
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
    tmp7 = tmp0.to(tl.int64)
    tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
    tmp10 = tl.where(xmask, tmp8, 0)
    tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
    tmp12 = tmp6.to(tl.int64)
    tmp13 = tmp12.to(tl.int32)
    tmp14 = tmp11.to(tl.int32)
    tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
    tl.store(out_ptr3 + (x3), tmp14, xmask)
''', device_str='cuda')


# NOTE(review): presumably this waits for the kernel compilations registered
# above and binds the results into this module's globals; the handle is then
# dropped. Confirm against AsyncCompile.wait.
async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        # Keep the sequence of partitions supplied by the caller.
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        # Replace every stored partition with fn(partition), pairing fns and
        # partitions by position; zip stops at the shorter of the two.
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (2, ), (1, ))
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf0 = empty_strided_cuda((2, 1, 16, 16), (256, 512, 16, 1), torch.int64)
            # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
            stream0 = get_raw_stream(0)
            triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 512, 16384, stream=stream0)
            del arg0_1
            buf15 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32)
            # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
            stream0 = get_raw_stream(0)
            triton_poi_fused_new_zeros_1.run(buf15, 544, stream=stream0)
            buf8 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32)
            # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
            stream0 = get_raw_stream(0)
            triton_poi_fused_new_zeros_1.run(buf8, 544, stream=stream0)
            buf6 = empty_strided_cuda((2, 1,
16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 32, 16, stream=stream0) + del buf0 + buf22 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 32, 16, stream=stream0) + del buf8 + buf19 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, 
aten.sort, aten._to_copy, aten.sum] + stream0 = get_raw_stream(0) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 32, 16, stream=stream0) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/f4/cf4r35cayrbjymekgfcgqjfnkedkuroswlxcm7c47gskkf72saio.py b/SpecForge-ext/cache/compiled_kernels/f4/cf4r35cayrbjymekgfcgqjfnkedkuroswlxcm7c47gskkf72saio.py new file mode 100644 index 0000000000000000000000000000000000000000..93855cc3e8b382ec36e6f1af6afa0337c4119f26 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/f4/cf4r35cayrbjymekgfcgqjfnkedkuroswlxcm7c47gskkf72saio.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 
'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = 
tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fp/cfpqdwmnrpuw6cpaelcqs7w2is6wrhnduqm3fuvr47rbdqaklmih.py b/SpecForge-ext/cache/compiled_kernels/fp/cfpqdwmnrpuw6cpaelcqs7w2is6wrhnduqm3fuvr47rbdqaklmih.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9c409c88608c487715f16f8e0c26584adcab5f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fp/cfpqdwmnrpuw6cpaelcqs7w2is6wrhnduqm3fuvr47rbdqaklmih.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/fp/f3bd0f5de2150ac1d396c0adc63a1f424457c3954e6099a9227a3c50e0829d99.best_config b/SpecForge-ext/cache/compiled_kernels/fp/f3bd0f5de2150ac1d396c0adc63a1f424457c3954e6099a9227a3c50e0829d99.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e55f7bfebe5b3c7f85806aa00c569d57b5dddc3c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fp/f3bd0f5de2150ac1d396c0adc63a1f424457c3954e6099a9227a3c50e0829d99.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "CYLQGL4LYUHWUHXWUOPJ5IHPQCUACWXMO577UIG54KOMYOQPA6IQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/fu/cfupjar3maclck7vvyyzejnr6bkivhdowjbp5qcswnnpaijbt7el.py b/SpecForge-ext/cache/compiled_kernels/fu/cfupjar3maclck7vvyyzejnr6bkivhdowjbp5qcswnnpaijbt7el.py new file mode 100644 index 
0000000000000000000000000000000000000000..f1b318aab6327925328a55255e7746b5f5374592 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fu/cfupjar3maclck7vvyyzejnr6bkivhdowjbp5qcswnnpaijbt7el.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/y5/cy5vf4gnz7zle5ospiqwxpaakdrocn2nwwwhum3styir6fju6b6t.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: 
[aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 151936][311164928, 151936, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:0" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:0" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[8, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 
'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = 
tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, 2048, 1), (2048, 1, 1), 
0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 16384, 151936, stream=stream0) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 151936), (311164928, 151936, 1), device='cuda:0', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fu/cfutdfbkthjdcl32iesehiivh7updz3ivwh7w57h66ptuk7446en.py b/SpecForge-ext/cache/compiled_kernels/fu/cfutdfbkthjdcl32iesehiivh7updz3ivwh7w57h66ptuk7446en.py new file mode 100644 index 0000000000000000000000000000000000000000..e855a26368fc438460f1c12cd13b3c501b20f98a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fu/cfutdfbkthjdcl32iesehiivh7updz3ivwh7w57h66ptuk7446en.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties 
+triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = 
r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/gg/776bcf53cacd547ee07c5f284e8c56dd1601f475eee53dbe560adf9ca667f02a.best_config b/SpecForge-ext/cache/compiled_kernels/gg/776bcf53cacd547ee07c5f284e8c56dd1601f475eee53dbe560adf9ca667f02a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a013ad4c2a7b9a18e1d475008c8b3e320dca3141 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gg/776bcf53cacd547ee07c5f284e8c56dd1601f475eee53dbe560adf9ca667f02a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 52, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/gg/cggaq6nrqpddpb3m26r424xkekrz73n6xpmfwjkrwdl244vaker2.py b/SpecForge-ext/cache/compiled_kernels/gg/cggaq6nrqpddpb3m26r424xkekrz73n6xpmfwjkrwdl244vaker2.py new file mode 100644 index 0000000000000000000000000000000000000000..99a9c9e1b7b74d4c3bcd61f506b4eaeed29d3f12 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gg/cggaq6nrqpddpb3m26r424xkekrz73n6xpmfwjkrwdl244vaker2.py @@ -0,0 +1,320 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, 
c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => 
squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# 
%full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 
'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, 
[XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/t4/ct4ja5lnaomv5oj7757f5xs5m47uf73w3ia3qy4uznjw6fr7z7gi.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor 
"bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + 
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, 
eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + 
tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48 + stream0 = get_raw_stream(0) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream0) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, 
getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream0 = get_raw_stream(0) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream0) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 8 + primals_11 = 32 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:0', dtype=torch.int64) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import 
compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/gg/cgggk6pegregqt4lolln3yxfp6wzahy6vf2ocae3vbpohfif7mtz.py b/SpecForge-ext/cache/compiled_kernels/gg/cgggk6pegregqt4lolln3yxfp6wzahy6vf2ocae3vbpohfif7mtz.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a032c03697a451e62bf6ed1ee5aed689633291 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gg/cgggk6pegregqt4lolln3yxfp6wzahy6vf2ocae3vbpohfif7mtz.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': 
True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + 
tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/gi/cgim5omgpzmqctohwn6fzcyz4k522sj2zt2nk2nh2j3m4l73x44y.py b/SpecForge-ext/cache/compiled_kernels/gi/cgim5omgpzmqctohwn6fzcyz4k522sj2zt2nk2nh2j3m4l73x44y.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1c754da2ebde3dc215c8d7053870f4e9b11615 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gi/cgim5omgpzmqctohwn6fzcyz4k522sj2zt2nk2nh2j3m4l73x44y.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 
'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = 
triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/gi/cgimoobnauaprveeizkzq45bhfd4z4orownh4u3daa6c2si3ruh2.py b/SpecForge-ext/cache/compiled_kernels/gi/cgimoobnauaprveeizkzq45bhfd4z4orownh4u3daa6c2si3ruh2.py new file mode 100644 index 0000000000000000000000000000000000000000..42115b5aed5a811e0faf9d999b11d033d87fd0ba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gi/cgimoobnauaprveeizkzq45bhfd4z4orownh4u3daa6c2si3ruh2.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = 
torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ca/ccakd5lwbwmf3z6wn4n7fcni3bewnfe7eagan6yl4dmoelh5awg6.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:7" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:7" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:7" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div 
+triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def 
triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, 
c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream7 = get_raw_stream(7) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream7) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/gi/cgizr5dsdmd4fht4iyis24kvfutr7cihykia6spr2gsornoa7wlt.py b/SpecForge-ext/cache/compiled_kernels/gi/cgizr5dsdmd4fht4iyis24kvfutr7cihykia6spr2gsornoa7wlt.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7e66edaa2433e0e7566fcedb8db2b98afcbd48 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/gi/cgizr5dsdmd4fht4iyis24kvfutr7cihykia6spr2gsornoa7wlt.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks 
import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4i/c4iwnhsf5kfmm7jnzrkyiv4x3yahjog6dyhf4prm2cjdi5xhllx2.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:6" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:6" = PlaceHolder[target=getitem_1] +# 
%convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, 
_tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream6 = get_raw_stream(6) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream6) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: 
call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ha/147d19502b6b2cc895f475c425cf10060732a7c43639841f4feb8663a06f5feb.best_config b/SpecForge-ext/cache/compiled_kernels/ha/147d19502b6b2cc895f475c425cf10060732a7c43639841f4feb8663a06f5feb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..96dd92ec5b0239e781a57a11e2928c5c0f286636 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/147d19502b6b2cc895f475c425cf10060732a7c43639841f4feb8663a06f5feb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "42NVHDOVRHC3TSIT2M6NVJU72L5EVVTGFXWS47GDCP2GM2XRN7KA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ha/cha4dxgiwtoqvo2tmipuixwe6hlkkkzkba3275dxpqk3fdiyykga.py b/SpecForge-ext/cache/compiled_kernels/ha/cha4dxgiwtoqvo2tmipuixwe6hlkkkzkba3275dxpqk3fdiyykga.py new file mode 100644 index 0000000000000000000000000000000000000000..69bfe3c6cdf997f9d20228c0203d28294f051e60 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/cha4dxgiwtoqvo2tmipuixwe6hlkkkzkba3275dxpqk3fdiyykga.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import 
extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bj/cbj72m23cmcn2yjoxrp4vabc2f76gw727jcpbi4y5oidokqenki5.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from 
torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, 
partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:3" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 
1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:3"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 
1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 
1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, ((s37 + 
127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 
127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= 
tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vr/cvrhnrmpgyxwu34xleclee3tt4kemoldkj7iam4uciathomirvlc.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# 
num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:3" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:3" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 32*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:3" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# 
%unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) 
+# %lt_4 : Tensor "b8[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3}) +# %where : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as 
tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, 
XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3j/c3j47dekusw3y4mohtk5v36cc6fso3wdtqn5oqjwew3yy3exjo76.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:3" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[2, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 32*(((s37 + 127)//128)), 16, 1]cuda:3" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = 
call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 
'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, 
descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf12 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream3) + buf19 = empty_strided_cuda((2, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 32 + 32*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream3) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, 16, (127 + s37) // 
128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 32*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 32*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream3) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((2, 1, 16, 
(127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream3) + del buf1 + del buf4 + buf26 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream3) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, 
(2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 32*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((2, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 32, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream3) + del buf5 + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 2*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, 
triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream3) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ha/chafrhmshbanw6gakzkfz6kn3ygixlmuadzppaazt2q5wzde6yqk.py b/SpecForge-ext/cache/compiled_kernels/ha/chafrhmshbanw6gakzkfz6kn3ygixlmuadzppaazt2q5wzde6yqk.py new file mode 100644 index 0000000000000000000000000000000000000000..afd2bdb7a70097e1dec736d1c02ff06979ed39e6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/chafrhmshbanw6gakzkfz6kn3ygixlmuadzppaazt2q5wzde6yqk.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + 
tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ha/chajbetuj2itvh432hm3fjzlcvjkiwnpmmlunen5nqg3p6jfssgj.py b/SpecForge-ext/cache/compiled_kernels/ha/chajbetuj2itvh432hm3fjzlcvjkiwnpmmlunen5nqg3p6jfssgj.py new file mode 100644 index 0000000000000000000000000000000000000000..2dcbf170825eee9424d650b7d132bc7d617208f9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/chajbetuj2itvh432hm3fjzlcvjkiwnpmmlunen5nqg3p6jfssgj.py @@ -0,0 +1,309 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from 
torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/al/cal2r4tfyw6gic3ggqyud3nufnajx6xau2koieoitx6zg4wsiozm.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# 
%primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:0" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, 
s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 
'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, 
tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ey/cey3ar6s7f2t62buescu5cctxdhf6hmbv3ps5d3tmh235oaj3fj6.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:0" = PlaceHolder[target=primals_14] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:0" = 
PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:0" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:0"[num_users=2] = 
call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 
16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = 
tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, 
getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream0 = get_raw_stream(0) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream0) + del primals_12 + ps1 = s24*s25 + buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9 + stream0 = get_raw_stream(0) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream0) + del primals_14 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:0', 
dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:0', dtype=torch.int64) + primals_9 = 1 + primals_10 = 2 + primals_11 = 32 + primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_13 = 8 + primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ha/chaussdp77qpxewcaeemzjepzpid324l76xe5n5k4fihzix4dkro.py b/SpecForge-ext/cache/compiled_kernels/ha/chaussdp77qpxewcaeemzjepzpid324l76xe5n5k4fihzix4dkro.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3dcc83a22eca0403ba83a505a9a1ccefef84dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/chaussdp77qpxewcaeemzjepzpid324l76xe5n5k4fihzix4dkro.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', 
index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 
* tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ha/chay7klvou5rg46wg7dr724knn4epmemt53vvv5ezibh5ws7fezs.py b/SpecForge-ext/cache/compiled_kernels/ha/chay7klvou5rg46wg7dr724knn4epmemt53vvv5ezibh5ws7fezs.py new file mode 100644 index 0000000000000000000000000000000000000000..64faa9e4b745428435010f4c92b179efebab1150 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ha/chay7klvou5rg46wg7dr724knn4epmemt53vvv5ezibh5ws7fezs.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import 
AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/hl/chl76ryeuxy3vqhn7zzfop2anxhx2ux6vnshomo7va3ohvp5db2o.py b/SpecForge-ext/cache/compiled_kernels/hl/chl76ryeuxy3vqhn7zzfop2anxhx2ux6vnshomo7va3ohvp5db2o.py new file mode 100644 index 0000000000000000000000000000000000000000..b8dc223a4b3da4425f7509b914cccd48a4adfb2e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hl/chl76ryeuxy3vqhn7zzfop2anxhx2ux6vnshomo7va3ohvp5db2o.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from 
torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + 
_tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py b/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fef5df3e9c0bea2b147aef4de360f4c1cde70c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 
'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + 
FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. 
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/hl/chlnqmsxkmscxiyicat3waah34qmfzufbkreco7bccy4ahvqzkja.py b/SpecForge-ext/cache/compiled_kernels/hl/chlnqmsxkmscxiyicat3waah34qmfzufbkreco7bccy4ahvqzkja.py new file mode 100644 index 0000000000000000000000000000000000000000..81b8323570fc93a6edcb076fde975a70cc0b8acd --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/hl/chlnqmsxkmscxiyicat3waah34qmfzufbkreco7bccy4ahvqzkja.py @@ -0,0 +1,184 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/lt/cltumflylhdpmjjimvyfnj6dyacaafhbhhzyughrfgaaygjiv7l3.py +# Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] +# Source node to ATen node mapping: +# hidden_states => 
convert_element_type +# hidden_states_1 => mul_16 +# mul_1 => mul_23 +# pow_1 => pow_1 +# rsqrt => rsqrt +# to_1 => convert_element_type_1 +# variance => mean +# Graph fragment: +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1" = PlaceHolder[target=primals_4] +# %buf0 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:1" = PlaceHolder[target=buf0] +# %primals_5 : Tensor "f64[][]cpu" = PlaceHolder[target=primals_5] +# %primals_7 : Tensor "bf16[s33][1]cuda:1" = PlaceHolder[target=primals_7] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %pow_1 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {}) +# %mean : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {}) +# %convert_element_type_default_1 : Tensor "f32[][]cpu"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_5, torch.float32), kwargs = {}) +# %add_tensor : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, %convert_element_type_default_1), kwargs = {}) +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_tensor,), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = 
call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_23 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_7, %convert_element_type_1), kwargs = {}) +# return %buf0,%rsqrt,%mul_23 +triton_red_fused__to_copy_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 
'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) +''', device_str='cuda') + + 
+async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_5, (), ()) + assert_size_stride(primals_7, (s33, ), (1, )) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((s47, s87, 1), (s87, 1, s47*s87), torch.float32) + buf1 = reinterpret_tensor(buf0, (s47, s87, 1), (s87, 1, 1), 0); del buf0 # reuse + buf2 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel = s47*s87 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0.run(buf1, primals_4, primals_5.item(), primals_7, buf2, s33, triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel, s33, stream=stream1) + del primals_5 + return (buf2, primals_4, primals_7, buf1, s47, s87, s33, s82, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = rand_strided((), (), device='cpu', dtype=torch.float64) 
+ primals_6 = 840433664 + primals_7 = rand_strided((4096, ), (1, ), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/hl/e97122d4a298cb48582351e305e51250374c92c0e5dd194e32e740bb8e99ce82.best_config b/SpecForge-ext/cache/compiled_kernels/hl/e97122d4a298cb48582351e305e51250374c92c0e5dd194e32e740bb8e99ce82.best_config new file mode 100644 index 0000000000000000000000000000000000000000..152165fbb3efcb363fa441db6ed00428fa02e545 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hl/e97122d4a298cb48582351e305e51250374c92c0e5dd194e32e740bb8e99ce82.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 16, "num_warps": 2, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 28, "triton_cache_hash": "JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/hw/chwm44jdqtovypwqknevqvz2d2xrazceb4ci2erooz4tahlocvzv.py b/SpecForge-ext/cache/compiled_kernels/hw/chwm44jdqtovypwqknevqvz2d2xrazceb4ci2erooz4tahlocvzv.py new file mode 100644 index 0000000000000000000000000000000000000000..5e20c29ee0a8f07be2b3cda8d1ae037424d9bb89 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hw/chwm44jdqtovypwqknevqvz2d2xrazceb4ci2erooz4tahlocvzv.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties 
+triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/hw/chwnvjc3qe5t3kaqygcq4upz4lcjepue5wpypsdd5cjmga5upazl.py b/SpecForge-ext/cache/compiled_kernels/hw/chwnvjc3qe5t3kaqygcq4upz4lcjepue5wpypsdd5cjmga5upazl.py new file mode 100644 index 0000000000000000000000000000000000000000..732e053c24db6c6bd7a19345ce96ed7e0ed78542 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hw/chwnvjc3qe5t3kaqygcq4upz4lcjepue5wpypsdd5cjmga5upazl.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes 
import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/w3/cw3k6gc52pbbxkrr6n2khmmz6lnybmqvwoosxgxpbotxdetxuz5h.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, 
(((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = 
tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gf/cgfbvvwjc2yj2x7xwrsusmg5ycv3kcfuxsdtllw6bjgcnmagt3ep.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:1" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 
127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:1"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:1"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = 
(%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:1"[num_users=1] = 
call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : 
Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 
127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return 
%sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 256, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 
'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < 
tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2x/c2xsu5ssb3jappbwwrbr53muiaoukfjzccks7reewucgvplouktq.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:1" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 
1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 
'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cb/ccbi42os5amfzotf76oejmkqyoxwwst4z6yh5t35hl5uieutsy7k.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => 
scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:1" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor 
"i64[2, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:1, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:1, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, 
((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, 
TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) 
>= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rc/crc3qjkg45bppwc6zjt3chvkhoetzafaxeolofypru5p7mguzfwr.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:1" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 
1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = 
xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6k/c6kudx3slegvkfu3wpku4cl3hnmrugqb6m3nzkue5pbweq5fr7xz.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:1" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:1" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=2] = 
call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pf/cpfql25wx6lqur42rsrcuy2k7352hze2vsa4tk5bnjcc2fpq7wvt.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, 
Max(1, ((s37 + 127)//128))]cuda:1" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 256}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), 
torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream1) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream1) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, 
aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream1) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream1) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream1) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + 
assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream1) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream1) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, 
num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream1) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, 
triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream1) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream1) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream1) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) 
// 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream1) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream1 = get_raw_stream(1) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream1) + del buf26 + return (buf27, buf29, buf30, buf32, 
buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1041 + arg1_1 = 1041 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:1', dtype=torch.int64) + arg3_1 = 1041 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/hw/ebac14ba1afffa0fc5fa33c387122f5af9b63c375d050128cd3f4e096bf0ccb1.best_config b/SpecForge-ext/cache/compiled_kernels/hw/ebac14ba1afffa0fc5fa33c387122f5af9b63c375d050128cd3f4e096bf0ccb1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7d56ea7451f6ff3ceffec392bc015b86ab20533e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/hw/ebac14ba1afffa0fc5fa33c387122f5af9b63c375d050128cd3f4e096bf0ccb1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "4UWYNBR3KPWQGNAZ5LIIRE7YAZWTQP4CP3JS6GOSLWYDF5K7WTAA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/i3/ci3rczvkjfzjjdwh5p2iayrefknfewqfaczix7f5ctyg3kjbeozr.py b/SpecForge-ext/cache/compiled_kernels/i3/ci3rczvkjfzjjdwh5p2iayrefknfewqfaczix7f5ctyg3kjbeozr.py new file mode 100644 index 0000000000000000000000000000000000000000..53e1ebaaa49a692b8bd37f60269168873d9c1fbd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/i3/ci3rczvkjfzjjdwh5p2iayrefknfewqfaczix7f5ctyg3kjbeozr.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from 
torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 
'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: 
Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/j6/96d096a43024530987db4ea996df0be2b36c9c72d0d854de4c0e243abc789527.best_config b/SpecForge-ext/cache/compiled_kernels/j6/96d096a43024530987db4ea996df0be2b36c9c72d0d854de4c0e243abc789527.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5a032d4d2fd9dd4986ad3ca853be0e503c6ef5e0 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/j6/96d096a43024530987db4ea996df0be2b36c9c72d0d854de4c0e243abc789527.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "CYLQGL4LYUHWUHXWUOPJ5IHPQCUACWXMO577UIG54KOMYOQPA6IQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/j6/cj6lb2lwab43zvel34z3wsdzgjns7efwyvjd2ycexqj3bnayivh6.py b/SpecForge-ext/cache/compiled_kernels/j6/cj6lb2lwab43zvel34z3wsdzgjns7efwyvjd2ycexqj3bnayivh6.py new file mode 100644 index 0000000000000000000000000000000000000000..0fecb8cd363e18f7c7894d14e77fba0131dfbb72 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/j6/cj6lb2lwab43zvel34z3wsdzgjns7efwyvjd2ycexqj3bnayivh6.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ja/cjac3canrif2aicy6istd4fkmpkoc3owxseahly23eb6yus63d7a.py b/SpecForge-ext/cache/compiled_kernels/ja/cjac3canrif2aicy6istd4fkmpkoc3owxseahly23eb6yus63d7a.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ed7836806197a05c6dd1cfab1e85dbd9c7512f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ja/cjac3canrif2aicy6istd4fkmpkoc3owxseahly23eb6yus63d7a.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jc/6d1f3e9e8f70908acd2fb5b21d8b0873b636ce2f79878a945366f1df268b92a7.best_config b/SpecForge-ext/cache/compiled_kernels/jc/6d1f3e9e8f70908acd2fb5b21d8b0873b636ce2f79878a945366f1df268b92a7.best_config new file mode 100644 index 0000000000000000000000000000000000000000..990be040d913054ee650201b25cf2c95af882efd --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/jc/6d1f3e9e8f70908acd2fb5b21d8b0873b636ce2f79878a945366f1df268b92a7.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "b70837e3723f218c7368cc2b49566dcd2bec3baf4c88b5e174a3f0822a6c86c0", "found_by_coordesc": false, "time_taken_ms": 142, "triton_cache_hash": "BZ2FPB5QIE7EHR6P7EPVPHR4HKS3YX3QQPIWQIT2R3EOJOAVWCGA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py b/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py new file mode 100644 index 0000000000000000000000000000000000000000..72436e6af33b6cefb9e7f55afd0c00fe5ea6fee5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git 
a/SpecForge-ext/cache/compiled_kernels/jc/cjckdkumcnlkvhcgcfckom4kb3kkdpks5eouyhcpnwscklfm3o54.py b/SpecForge-ext/cache/compiled_kernels/jc/cjckdkumcnlkvhcgcfckom4kb3kkdpks5eouyhcpnwscklfm3o54.py new file mode 100644 index 0000000000000000000000000000000000000000..84be8597e084ebb1c03187e86a638ceb3b5b9cfe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jc/cjckdkumcnlkvhcgcfckom4kb3kkdpks5eouyhcpnwscklfm3o54.py @@ -0,0 +1,63 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 
'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/jc/e53b4e416d29cb48686db7f2d878a68523d3468827220b5d15ea6a75f18b24e0.best_config b/SpecForge-ext/cache/compiled_kernels/jc/e53b4e416d29cb48686db7f2d878a68523d3468827220b5d15ea6a75f18b24e0.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cd9795263343a19ee8f06cf527807cd2d9adfee5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jc/e53b4e416d29cb48686db7f2d878a68523d3468827220b5d15ea6a75f18b24e0.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 73, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jf/cjfcccbml22u4digkobrilbxkt7z5njzbtumydsv4gwsqqbbakaw.py b/SpecForge-ext/cache/compiled_kernels/jf/cjfcccbml22u4digkobrilbxkt7z5njzbtumydsv4gwsqqbbakaw.py new file mode 100644 index 0000000000000000000000000000000000000000..14661a88ab81544f97c2e376f5066274b6278c77 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jf/cjfcccbml22u4digkobrilbxkt7z5njzbtumydsv4gwsqqbbakaw.py @@ -0,0 +1,334 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten 
+inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/el/celb4xosanyf3m2sx6v3t54w4bgkx65m4lb2newu7nggikw6jbxj.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = 
call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %buf0 +triton_red_fused__to_copy_mul_sum_0 = async_compile.triton('triton_red_fused__to_copy_mul_sum_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), 
tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py +# Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# to_1 => convert_element_type_1 +# Graph fragment: +# %buf0 : Tensor "f32[1, 1, s33, 32][32*s33, 32*s33, 1, s33]cuda:4" = PlaceHolder[target=buf0] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_28 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %convert_element_type_1), kwargs = {}) +# %sum_1 : Tensor "bf16[1, 1, s33][s33, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_28, [0, 1], True), kwargs = {}) +# return %sum_1 +triton_per_fused__to_copy_mul_sum_1 = async_compile.triton('triton_per_fused__to_copy_mul_sum_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + 
+@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 
+ (x0), tmp4, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ha/chaussdp77qpxewcaeemzjepzpid324l76xe5n5k4fihzix4dkro.py +# Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# Graph fragment: +# %tangents_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %primals_7 : Tensor "bf16[s33][1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4" = PlaceHolder[target=primals_4] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4" = PlaceHolder[target=rsqrt] +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:4" = PlaceHolder[target=sum_2] +# %mul_27 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %primals_7), kwargs = {}) +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %convert_element_type_2 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.float32), kwargs = {}) +# %mul_29 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %convert_element_type), kwargs = {}) +# %mul_30 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt), kwargs = {}) +# %sum_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_29, [2], True), 
kwargs = {}) +# %pow_2 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt, 3), kwargs = {}) +# %mul_31 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%sum_2, -0.5), kwargs = {}) +# %mul_32 : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_31, %pow_2), kwargs = {}) +# %expand : Tensor "f32[s47, s87, s33][s87, 1, 0]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul_32, [%primals_1, %primals_2, %primals_3]), kwargs = {}) +# %div : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand, %primals_3), kwargs = {}) +# %pow_3 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 1.0), kwargs = {}) +# %mul_33 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_3, 2.0), kwargs = {}) +# %mul_34 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div, %mul_33), kwargs = {}) +# %add_37 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_30, %mul_34), kwargs = {}) +# %convert_element_type_3 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_37, torch.bfloat16), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2 = async_compile.triton('triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import 
triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + 
r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 = tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) +''', device_str='cuda') + + 
+async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_7, (s33, ), (1, )) + assert_size_stride(rsqrt, (s47, s87, 1), (s87, 1, 1)) + assert_size_stride(tangents_1, (s47, s87, s33), (s33*s87, s33, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 1, s33, 32), (32*s33, 32*s33, 1, s33), torch.float32) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + triton_red_fused__to_copy_mul_sum_0_xnumel = 32*s33 + triton_red_fused__to_copy_mul_sum_0_r0_numel = (31 + s47*s87) // 32 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_mul_sum_0.run(tangents_1, primals_4, rsqrt, buf0, s33, s47, s87, triton_red_fused__to_copy_mul_sum_0_xnumel, triton_red_fused__to_copy_mul_sum_0_r0_numel, stream=stream4) + buf1 = empty_strided_cuda((1, 1, s33), (s33, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, hidden_states_1, to_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_mul_sum_1.run(buf0, buf1, s33, s33, 32, stream=stream4) + del buf0 + buf3 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.pow, aten.expand, aten.div, aten.add] + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel 
= s47*s87 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2.run(tangents_1, primals_7, primals_4, rsqrt, buf3, s33, triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2_xnumel, s33, stream=stream4) + del primals_4 + del primals_7 + del rsqrt + del tangents_1 + return (None, None, None, buf3, None, None, reinterpret_tensor(buf1, (s33, ), (1, ), 0), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_6 = 840433664 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = rand_strided((4096, ), (1, ), device='cuda:4', dtype=torch.bfloat16) + rsqrt = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_6, primals_4, primals_7, rsqrt, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ji/cjij26clq6lcv6c2plwk2zxldtphmt23swyyv2i3vq3ujc4fkjp5.py b/SpecForge-ext/cache/compiled_kernels/ji/cjij26clq6lcv6c2plwk2zxldtphmt23swyyv2i3vq3ujc4fkjp5.py new file mode 100644 index 0000000000000000000000000000000000000000..3fade282c9be1e930d9b5b7b076a915d90537b1b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ji/cjij26clq6lcv6c2plwk2zxldtphmt23swyyv2i3vq3ujc4fkjp5.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + 
tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ji/dcd5379fb575ef4fd3e038c06141a69d8219d8c2beb121ade999431f3eddafdf.best_config b/SpecForge-ext/cache/compiled_kernels/ji/dcd5379fb575ef4fd3e038c06141a69d8219d8c2beb121ade999431f3eddafdf.best_config new file mode 100644 index 0000000000000000000000000000000000000000..8920a6ebe9dac1a267cf3c5b5085d70019ad08a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ji/dcd5379fb575ef4fd3e038c06141a69d8219d8c2beb121ade999431f3eddafdf.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 36, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jq/cjq5hv4rnv3k5awzzq6t2f4dupyimqnzm5i36pci6ox5vpquu66l.py b/SpecForge-ext/cache/compiled_kernels/jq/cjq5hv4rnv3k5awzzq6t2f4dupyimqnzm5i36pci6ox5vpquu66l.py new file mode 100644 index 
0000000000000000000000000000000000000000..cbc17594e946f38957030129428016c7a0bf3903 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jq/cjq5hv4rnv3k5awzzq6t2f4dupyimqnzm5i36pci6ox5vpquu66l.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def 
triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/js/cjsacdzbsm56bj2gkhqa3emfz7vsvobriphpb6aucurwhb2km5xx.py b/SpecForge-ext/cache/compiled_kernels/js/cjsacdzbsm56bj2gkhqa3emfz7vsvobriphpb6aucurwhb2km5xx.py new file mode 100644 index 0000000000000000000000000000000000000000..d5291056a5204da99fcde4189b13691647a51301 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/js/cjsacdzbsm56bj2gkhqa3emfz7vsvobriphpb6aucurwhb2km5xx.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], 
(15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : 
tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/js/cjse6ak6jsp3o35wdszmvjyn4cqeqewbex3a5ks2m6fqecygrmmg.py b/SpecForge-ext/cache/compiled_kernels/js/cjse6ak6jsp3o35wdszmvjyn4cqeqewbex3a5ks2m6fqecygrmmg.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd75894be68c7ef2a4002abbeafda900487c23e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/js/cjse6ak6jsp3o35wdszmvjyn4cqeqewbex3a5ks2m6fqecygrmmg.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + 
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = 
r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/js/cjsua5bk75g3qrlvlcodjiakenr5rs5kkcshkbp7bxk6gmv42cax.py b/SpecForge-ext/cache/compiled_kernels/js/cjsua5bk75g3qrlvlcodjiakenr5rs5kkcshkbp7bxk6gmv42cax.py new file mode 100644 index 0000000000000000000000000000000000000000..caa0fcd0dba84d6a5768319dfb9c1c8f36873253 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/js/cjsua5bk75g3qrlvlcodjiakenr5rs5kkcshkbp7bxk6gmv42cax.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 
'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 
'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is 
always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. 
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * 
GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/jy/cjyosafwol2ofpnwilyfdmxteu632gjxp7ja4dtj34fk4ylp7t3p.py b/SpecForge-ext/cache/compiled_kernels/jy/cjyosafwol2ofpnwilyfdmxteu632gjxp7ja4dtj34fk4ylp7t3p.py new file mode 100644 index 0000000000000000000000000000000000000000..2880791dbd5321c42743ff9028da16fc294dcd39 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jy/cjyosafwol2ofpnwilyfdmxteu632gjxp7ja4dtj34fk4ylp7t3p.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4y/c4yua3qi2b3xk6rn6ls5sdrsrpavp4zes7z62ki32y5ijfhzw4bb.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_5 : Tensor 
"i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], 
(5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : 
tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream5) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/jy/cjyz56auvrv5dtttgluspkuhcmpwyzr7c5yy4v7bqzoytljurdsp.py b/SpecForge-ext/cache/compiled_kernels/jy/cjyz56auvrv5dtttgluspkuhcmpwyzr7c5yy4v7bqzoytljurdsp.py new file mode 100644 index 0000000000000000000000000000000000000000..7a69d1d66ff751962c32cfee8f0fbaf57c845bbf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/jy/cjyz56auvrv5dtttgluspkuhcmpwyzr7c5yy4v7bqzoytljurdsp.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, 
math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, 
tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/k2/ck2soaxletvrzenn54tnhajabovvb5jlfowm66eg3l6mpcgjngtg.py b/SpecForge-ext/cache/compiled_kernels/k2/ck2soaxletvrzenn54tnhajabovvb5jlfowm66eg3l6mpcgjngtg.py new file mode 100644 index 0000000000000000000000000000000000000000..236030e74fb654b68418a920ac4922a7044c23dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/k2/ck2soaxletvrzenn54tnhajabovvb5jlfowm66eg3l6mpcgjngtg.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import 
run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/c3/cc3guwnwiox3yzzjtaquh6k4sm6nn4lcmkep56rop3grqr44xorh.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 
2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 
'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wd/cwddux6qmu2dak3xt2ktqi6irvxaesmlncmzcf4b3cydua7hsfjf.py +# Topologically Sorted Source Nodes: [], Original ATen: 
[aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:7" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:7" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, 
pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): 
[['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + 
OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. 
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + 
assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream7) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del 
primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:7', 
dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ka/8b03f17ad29941e9ea7ac150ef3f353df649b4edd9ed418abed5fcd1c7bde6c2.best_config b/SpecForge-ext/cache/compiled_kernels/ka/8b03f17ad29941e9ea7ac150ef3f353df649b4edd9ed418abed5fcd1c7bde6c2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ka/8b03f17ad29941e9ea7ac150ef3f353df649b4edd9ed418abed5fcd1c7bde6c2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ka/ckakj5oalaoxcoqpyrpbjs75fxr6p2sp5eymgmxpomaulof2nifw.py b/SpecForge-ext/cache/compiled_kernels/ka/ckakj5oalaoxcoqpyrpbjs75fxr6p2sp5eymgmxpomaulof2nifw.py new file mode 100644 index 0000000000000000000000000000000000000000..d46bdd51e028ccebc0cbce9426944ab9ee473c2e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ka/ckakj5oalaoxcoqpyrpbjs75fxr6p2sp5eymgmxpomaulof2nifw.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics 
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + 
x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ka/ckakxdyadsnew23mzaldde5vr4hhxvcp5mq45gbmb5gmgup5ucok.py b/SpecForge-ext/cache/compiled_kernels/ka/ckakxdyadsnew23mzaldde5vr4hhxvcp5mq45gbmb5gmgup5ucok.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc8e1a605439f5e3270c062f6261c25a41a29ad --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ka/ckakxdyadsnew23mzaldde5vr4hhxvcp5mq45gbmb5gmgup5ucok.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math 
+import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zs/czs5ajn54p3jyxmsxcfenxtcm3rwqng63ls3udjpktpl3vy352ky.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_2] +# 
%primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : 
tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream4) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ka/ckalrgzcc4h4y7ugp6g6gajbzlmuubj2lzzo2a2rlyn3p3llgyro.py b/SpecForge-ext/cache/compiled_kernels/ka/ckalrgzcc4h4y7ugp6g6gajbzlmuubj2lzzo2a2rlyn3p3llgyro.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd39781c7869716f41d3ce7bd1a93e3ae63c5b6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ka/ckalrgzcc4h4y7ugp6g6gajbzlmuubj2lzzo2a2rlyn3p3llgyro.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = 
arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + 
stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/ki/420d573745373e2da6506b192fb5c004f8ab97222df7c77d3227f04d688e6aac.best_config b/SpecForge-ext/cache/compiled_kernels/ki/420d573745373e2da6506b192fb5c004f8ab97222df7c77d3227f04d688e6aac.best_config new file mode 100644 index 0000000000000000000000000000000000000000..bccc339c530946640745852b622c73570887ab86 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ki/420d573745373e2da6506b192fb5c004f8ab97222df7c77d3227f04d688e6aac.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "CLTRXNE5MHPP3O5A5W4Z4EQTTZVYMOP5IPJT6N44O6FTBZFXLMNA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/ki/b7cae6f54b72f2b8d7f5e29e45d4897b377c87b665df0ffe6dcae6711b6f1b57.best_config b/SpecForge-ext/cache/compiled_kernels/ki/b7cae6f54b72f2b8d7f5e29e45d4897b377c87b665df0ffe6dcae6711b6f1b57.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ki/b7cae6f54b72f2b8d7f5e29e45d4897b377c87b665df0ffe6dcae6711b6f1b57.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ki/ckilhdvljb7gisvjeh27eoct2t7a3jnlamhdjdlsk4sziursotb7.py b/SpecForge-ext/cache/compiled_kernels/ki/ckilhdvljb7gisvjeh27eoct2t7a3jnlamhdjdlsk4sziursotb7.py new file mode 100644 index 0000000000000000000000000000000000000000..2390e9c632b0438c91a7fdba43fe027c7494a076 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ki/ckilhdvljb7gisvjeh27eoct2t7a3jnlamhdjdlsk4sziursotb7.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): 
[['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ki/ckinza2erpbni3br5d2bkasocypmw65p43nltocsbky5dit6v5gn.py b/SpecForge-ext/cache/compiled_kernels/ki/ckinza2erpbni3br5d2bkasocypmw65p43nltocsbky5dit6v5gn.py new file mode 100644 index 0000000000000000000000000000000000000000..2854640c39fae2c863fb6f2a456dc131d55f5b58 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ki/ckinza2erpbni3br5d2bkasocypmw65p43nltocsbky5dit6v5gn.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 
'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, 
other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/kn/ckndvcvzkufaaxp7zvavaxav3yjtn5ovzc46qyzuzojreblqoq62.py b/SpecForge-ext/cache/compiled_kernels/kn/ckndvcvzkufaaxp7zvavaxav3yjtn5ovzc46qyzuzojreblqoq62.py new file mode 100644 index 0000000000000000000000000000000000000000..0a15e79785877ebdc50d8ca3694b7a4040bedf75 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kn/ckndvcvzkufaaxp7zvavaxav3yjtn5ovzc46qyzuzojreblqoq62.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, 
arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: 
Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/kv/ckvlggwgq6tgpe4wofbzgxvnod4mevkevzeoso5lokg4oamzmnwy.py b/SpecForge-ext/cache/compiled_kernels/kv/ckvlggwgq6tgpe4wofbzgxvnod4mevkevzeoso5lokg4oamzmnwy.py new file mode 100644 index 0000000000000000000000000000000000000000..e98e1551b6bd5fb798133282d70542032aa9b461 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/kv/ckvlggwgq6tgpe4wofbzgxvnod4mevkevzeoso5lokg4oamzmnwy.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nz/cnzssz7d54icfw7gemvqefbnv4guf7ha2jsvspqqdi4gyentn32i.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hv/chvj5h3adlnuxifatrhlirixthstwv5pzbxvuapjby5cz2npck63.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, 
index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 8*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_2 : Tensor 
"i64[s12][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:4"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, s12][Max(1, s12), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# 
%view_3 : Tensor "b8[8, s12, 1][Max(1, s12), 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 
127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton 
+import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, 
ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 
0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2e/c2etayrlw6ivbtj3uahv4l3y7x534xpzfww6cyknbe2kfe54yei5.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 8*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, ((s12 + 127)//128)][((s12 + 127)//128), 8*(((s12 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[8, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, 
torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = 
xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nr/cnrsai53e7ialdlmhahe4zhhr3uynctmotdwm2ifjluqipqewdbf.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 8*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, ((s12 + 
127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# 
%unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[8, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4}) +# %where : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 
'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 
= tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fo/cfooe7ht55q5jhejzd3zyb3g5v64cvxjohkxeadllgnjxgiwo52v.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as 
tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/si/csikk3n53lwihovd25mpe5kyjf7hnym4zk3xgfxvpcikm4bgg35n.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': 
True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/oo/coo6yp6pxkzn6fs2k6qo3o6btpvpbycwufcwhpka6lqffc6e2vth.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 
127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 
'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf12 = empty_strided_cuda((8, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 8*((127 + s12) // 128) + 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = 
get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + buf21 = empty_strided_cuda((8, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 8*((127 + s12) // 128) + 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 8*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 8*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = 
get_raw_stream(4) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream4) + del arg2_1 + buf10 = empty_strided_cuda((8, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 8*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream4) + buf19 = empty_strided_cuda((8, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 8*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 8*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 
1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream4) + del buf4 + buf14 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream4) + del buf12 + buf32 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 8*((127 + s37) // 128) + 
triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 8*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream4) + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 
128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream4) + del buf21 + buf29 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 8*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 
128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream4) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((8, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 8*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream4) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from 
torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1896 + arg1_1 = 1896 + arg2_1 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64) + arg3_1 = 1896 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ky/ckymincticcpi6whoxumnurwhspwbrhpbcg34533u5yjkbf7m3oy.py b/SpecForge-ext/cache/compiled_kernels/ky/ckymincticcpi6whoxumnurwhspwbrhpbcg34533u5yjkbf7m3oy.py new file mode 100644 index 0000000000000000000000000000000000000000..bda955366eca9521f58a9894e8ea1e1a7b6aaf2e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ky/ckymincticcpi6whoxumnurwhspwbrhpbcg34533u5yjkbf7m3oy.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 
16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = 
tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ky/ckyo727c23ds2ldvxtv4ow3bw4pcn7d7h434mir4gk5k6c4aoyzb.py b/SpecForge-ext/cache/compiled_kernels/ky/ckyo727c23ds2ldvxtv4ow3bw4pcn7d7h434mir4gk5k6c4aoyzb.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed2b7991be6cf2f5364afe54ec41160f66cd809 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ky/ckyo727c23ds2ldvxtv4ow3bw4pcn7d7h434mir4gk5k6c4aoyzb.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, 
in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in 
GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/lc/clckpzwmdsjy3vedxl7b4lnrjahvk5aoiizlcfms6pwh3lq3qkrl.py b/SpecForge-ext/cache/compiled_kernels/lc/clckpzwmdsjy3vedxl7b4lnrjahvk5aoiizlcfms6pwh3lq3qkrl.py new file mode 100644 index 0000000000000000000000000000000000000000..3502799534c4f57c6cbc30f5efb36f3af647a9fa --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lc/clckpzwmdsjy3vedxl7b4lnrjahvk5aoiizlcfms6pwh3lq3qkrl.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 
'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 
'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, 
K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. 
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + 
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/le/3e648a8e4777711ec9bc5328bd7f7c82364c1e523ca376587ecf381eb26afafd.best_config b/SpecForge-ext/cache/compiled_kernels/le/3e648a8e4777711ec9bc5328bd7f7c82364c1e523ca376587ecf381eb26afafd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0102fea510b9bf77ab661e714dfc816c066dc0d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/le/3e648a8e4777711ec9bc5328bd7f7c82364c1e523ca376587ecf381eb26afafd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/le/cle7calkycyqdhov5vx6gchjr4rqfyvu5ylug6qms3lpm3hsggbf.py b/SpecForge-ext/cache/compiled_kernels/le/cle7calkycyqdhov5vx6gchjr4rqfyvu5ylug6qms3lpm3hsggbf.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a468ee60b1c96dbb128b8c053e0888677a7a15 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/le/cle7calkycyqdhov5vx6gchjr4rqfyvu5ylug6qms3lpm3hsggbf.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + 
(x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/le/cleamjoeanagom3jardcugxvitolsddd3gk5g5kwo5srz3tb7zvf.py b/SpecForge-ext/cache/compiled_kernels/le/cleamjoeanagom3jardcugxvitolsddd3gk5g5kwo5srz3tb7zvf.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf79526442228fc6725e578f2ea7378748749fa --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/le/cleamjoeanagom3jardcugxvitolsddd3gk5g5kwo5srz3tb7zvf.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], 
(12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : 
tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. 
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/le/clegjpwgwvwv2zw3wetfizxlrgbyvzek4xsj7xz53lvuve6dxvpd.py b/SpecForge-ext/cache/compiled_kernels/le/clegjpwgwvwv2zw3wetfizxlrgbyvzek4xsj7xz53lvuve6dxvpd.py new file mode 100644 index 0000000000000000000000000000000000000000..50a7d315d68ea47c77eb437c9fe7f9b25dc7f0a6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/le/clegjpwgwvwv2zw3wetfizxlrgbyvzek4xsj7xz53lvuve6dxvpd.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, 
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, 
tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) 
diff --git a/SpecForge-ext/cache/compiled_kernels/le/clehbyg4yeecmekf5ytdyw6cumf5bexyvwy4tg4l6tlcizrncjzu.py b/SpecForge-ext/cache/compiled_kernels/le/clehbyg4yeecmekf5ytdyw6cumf5bexyvwy4tg4l6tlcizrncjzu.py new file mode 100644 index 0000000000000000000000000000000000000000..6286bcf84f36547305da9e03eb66d5b536b7113e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/le/clehbyg4yeecmekf5ytdyw6cumf5bexyvwy4tg4l6tlcizrncjzu.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/lu/2a1eb51ccb82dc5f612d559afabf327d3525e30f4714104b8eeffb76140cf76a.best_config b/SpecForge-ext/cache/compiled_kernels/lu/2a1eb51ccb82dc5f612d559afabf327d3525e30f4714104b8eeffb76140cf76a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b06045019d512d071043c60a2787e7e25be43ca7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lu/2a1eb51ccb82dc5f612d559afabf327d3525e30f4714104b8eeffb76140cf76a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 198, "triton_cache_hash": "YHAVDQXMEVV7S4RZ3RZ2CHWHFBN2O3IAF5U3VLSP72AQHADF3BWQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/lu/cluxwvj6orztk73a6fbxjkhuoenzcjuxfo67uhm7xsuadpewdeyg.py b/SpecForge-ext/cache/compiled_kernels/lu/cluxwvj6orztk73a6fbxjkhuoenzcjuxfo67uhm7xsuadpewdeyg.py new file mode 100644 index 0000000000000000000000000000000000000000..773e5b90e8e868c7d781251a3e740873104288a7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lu/cluxwvj6orztk73a6fbxjkhuoenzcjuxfo67uhm7xsuadpewdeyg.py @@ -0,0 +1,48 @@ + +import triton 
+import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, 
R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/me/cmepf7arvjmdadawms6frhxuz4xsbsxt7fth33p3y3hmoextslbz.py b/SpecForge-ext/cache/compiled_kernels/me/cmepf7arvjmdadawms6frhxuz4xsbsxt7fth33p3y3hmoextslbz.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4a5071dc539b44bd7c011cf0c11330b7e4bf5c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/me/cmepf7arvjmdadawms6frhxuz4xsbsxt7fth33p3y3hmoextslbz.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl 
+from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/l4/cl4mfrugc46xgxleh7kty7kuchbazhslnsdosx2m3jtt7ezzmr56.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:3" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:3" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:3" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), 
kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = 
r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream3 = get_raw_stream(3) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream3) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/me/cmetp7q6n3svbn4t3birtvvfederz3hbjif6mn5o5ni7kp5vy6yi.py 
b/SpecForge-ext/cache/compiled_kernels/me/cmetp7q6n3svbn4t3birtvvfederz3hbjif6mn5o5ni7kp5vy6yi.py new file mode 100644 index 0000000000000000000000000000000000000000..16e8d675f75cb9b11b601204b70708bcb86f9d1f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/me/cmetp7q6n3svbn4t3birtvvfederz3hbjif6mn5o5ni7kp5vy6yi.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2x/c2xgz3ru7j7sptpmoelww3e5lkmoeimpyawjjwmcpaujxtdorhwr.py +# 
Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# 
%index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] 
+ xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6r/c6r6adrqwwhzfcdd5cyhmwl3cptpvwwhedzdpranw7esxeg5oyia.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node 
mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:6" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:6"[num_users=2] = 
call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, 
DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, 
eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + 
assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream6 = get_raw_stream(6) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream6) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream6 = get_raw_stream(6) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream6) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, 
s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:6', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/mh/cmhqiwaz2sq3txsgzgknchh6p4qp4xkdhtkilgy67m32sbzzukyz.py b/SpecForge-ext/cache/compiled_kernels/mh/cmhqiwaz2sq3txsgzgknchh6p4qp4xkdhtkilgy67m32sbzzukyz.py new file mode 100644 index 0000000000000000000000000000000000000000..8a87a2f5f2e8a02488a8169fd4bd6b9d17da880e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mh/cmhqiwaz2sq3txsgzgknchh6p4qp4xkdhtkilgy67m32sbzzukyz.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', 
'_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = 
arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. 
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + 
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/mm/599951369d68f94ff34037a7ca4966f84f3af6db065106a4824a7babf7de6b9d.best_config b/SpecForge-ext/cache/compiled_kernels/mm/599951369d68f94ff34037a7ca4966f84f3af6db065106a4824a7babf7de6b9d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0f3077839695e120586afcad67e6cbed72fc84a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mm/599951369d68f94ff34037a7ca4966f84f3af6db065106a4824a7babf7de6b9d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "50b7a7455b8a2aa7fe5b57654ddf092584f02f34b265601866fdd653f06a5539", "found_by_coordesc": false, "time_taken_ms": 74, "triton_cache_hash": "GEZC7BNCXFQAGCZIOI2BQLAAUGS4IVUJ4QGCDMFUE3MMZMGBMJIQ"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/mm/cmm6gxzvdt4w4heysmumz7etupuzsyhhinqrbrb7e7uar4xfr27g.py b/SpecForge-ext/cache/compiled_kernels/mm/cmm6gxzvdt4w4heysmumz7etupuzsyhhinqrbrb7e7uar4xfr27g.py new file mode 100644 index 0000000000000000000000000000000000000000..69e23a5f07285198b61c6e878cdf29e0330c5752 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mm/cmm6gxzvdt4w4heysmumz7etupuzsyhhinqrbrb7e7uar4xfr27g.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mm/cmmjeni3vm27vuxpkwem5trkvnesbjokn2d5noz6sbcpdfdl3cid.py b/SpecForge-ext/cache/compiled_kernels/mm/cmmjeni3vm27vuxpkwem5trkvnesbjokn2d5noz6sbcpdfdl3cid.py new file mode 100644 index 0000000000000000000000000000000000000000..797553e6ed35aadc84fce926dcac677c04e8d8ee --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mm/cmmjeni3vm27vuxpkwem5trkvnesbjokn2d5noz6sbcpdfdl3cid.py @@ -0,0 +1,46 @@ 
+ +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = 
tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/mm/cmmp5cb4b4xchyyotwouonjdn4i7oimojhwosocnjqx2t5kcq5jf.py b/SpecForge-ext/cache/compiled_kernels/mm/cmmp5cb4b4xchyyotwouonjdn4i7oimojhwosocnjqx2t5kcq5jf.py new file mode 100644 index 0000000000000000000000000000000000000000..8b938af86bcba50602704af2840f2726b974ba65 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mm/cmmp5cb4b4xchyyotwouonjdn4i7oimojhwosocnjqx2t5kcq5jf.py @@ -0,0 +1,63 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 
'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, 
eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) diff --git a/SpecForge-ext/cache/compiled_kernels/mm/cmmzajy4xf5ahbs253khfun55p7rfbe253rdujk6gw2gr77mbn37.py b/SpecForge-ext/cache/compiled_kernels/mm/cmmzajy4xf5ahbs253khfun55p7rfbe253rdujk6gw2gr77mbn37.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b8cda22411d46adfe9522395fb7a80d01d603a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mm/cmmzajy4xf5ahbs253khfun55p7rfbe253rdujk6gw2gr77mbn37.py @@ -0,0 +1,320 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from 
torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wp/cwp3fb3c4kq3jjlj5mue2pouzku7f5r3znogrbfwtcofaovvgmqa.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 
1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_10, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 
9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, 
TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, 
eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = 
tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5g/c5g26r4ygcctmxuptx453t3kikkqukh73touvd4yxv7futs36kgf.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = 
call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:1, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), 
kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', 
other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + 
args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s9*s48*s48 + stream1 = get_raw_stream(1) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream1) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream1 = get_raw_stream(1) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream1) + del 
primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 8 + primals_11 = 32 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:1', dtype=torch.int64) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + tangents_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/mm/d7fbe0be67fee394a360fca69e2ba390f4e1580a88e6f242fc545d84004415a3.best_config b/SpecForge-ext/cache/compiled_kernels/mm/d7fbe0be67fee394a360fca69e2ba390f4e1580a88e6f242fc545d84004415a3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d3b58d3f4d2d57923b7dd1b74f7ebedcc091842a --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/mm/d7fbe0be67fee394a360fca69e2ba390f4e1580a88e6f242fc545d84004415a3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 15, "triton_cache_hash": "S3UH64TOYTN473KAATRMGKZ5SLQ46EZYJVPR6TIL7QNMYCB3MSMA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ms/cmss4vkiru5swhxy33hnse2fxrcngpiw7ozbdu4mwbcfqlcznlyx.py b/SpecForge-ext/cache/compiled_kernels/ms/cmss4vkiru5swhxy33hnse2fxrcngpiw7ozbdu4mwbcfqlcznlyx.py new file mode 100644 index 0000000000000000000000000000000000000000..acf3fa6432fa433b975edd1e65acfcd5823c526e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ms/cmss4vkiru5swhxy33hnse2fxrcngpiw7ozbdu4mwbcfqlcznlyx.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 
'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mu/cmujymfqcpdztl6mvthtg7e3oyr4wtaoz6javbq3nk2aj4dhshhs.py b/SpecForge-ext/cache/compiled_kernels/mu/cmujymfqcpdztl6mvthtg7e3oyr4wtaoz6javbq3nk2aj4dhshhs.py new file mode 100644 index 0000000000000000000000000000000000000000..5484dfc0b0ecfb57a25776fffa8f847c4e941a20 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mu/cmujymfqcpdztl6mvthtg7e3oyr4wtaoz6javbq3nk2aj4dhshhs.py @@ -0,0 +1,72 @@ + +import 
triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = 
R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/mx/439bc7c4254f47f871e55d387d037de50f70d03b0e93eed05e6631d64d2130f0.best_config b/SpecForge-ext/cache/compiled_kernels/mx/439bc7c4254f47f871e55d387d037de50f70d03b0e93eed05e6631d64d2130f0.best_config new file mode 100644 index 
0000000000000000000000000000000000000000..c2d9b36c5180887fa413aa1eb230c04dc216dd00 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mx/439bc7c4254f47f871e55d387d037de50f70d03b0e93eed05e6631d64d2130f0.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/mx/cmxtehlygzy2cjddulwsvjghigetqtozdl5ft6qfk3edunt3obku.py b/SpecForge-ext/cache/compiled_kernels/mx/cmxtehlygzy2cjddulwsvjghigetqtozdl5ft6qfk3edunt3obku.py new file mode 100644 index 0000000000000000000000000000000000000000..43fd87c359ef1cc61444471b43df7fdc26075c9a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/mx/cmxtehlygzy2cjddulwsvjghigetqtozdl5ft6qfk3edunt3obku.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 
'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + 
x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/n2/715cf3664139902b9b998a7ab2378e02ebe9b3f3148b727b65d78a661e7c2657.best_config b/SpecForge-ext/cache/compiled_kernels/n2/715cf3664139902b9b998a7ab2378e02ebe9b3f3148b727b65d78a661e7c2657.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b06045019d512d071043c60a2787e7e25be43ca7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n2/715cf3664139902b9b998a7ab2378e02ebe9b3f3148b727b65d78a661e7c2657.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 198, "triton_cache_hash": "YHAVDQXMEVV7S4RZ3RZ2CHWHFBN2O3IAF5U3VLSP72AQHADF3BWQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/n2/cn24lurjdnbidkarxbtzqpcvotiay3hsbqwsbqw73gg63elg6tak.py b/SpecForge-ext/cache/compiled_kernels/n2/cn24lurjdnbidkarxbtzqpcvotiay3hsbqwsbqw73gg63elg6tak.py new file mode 100644 index 0000000000000000000000000000000000000000..4d781e2eb5d532e293f057cde175f03d55275adb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n2/cn24lurjdnbidkarxbtzqpcvotiay3hsbqwsbqw73gg63elg6tak.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': 
DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= 
tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/n2/cn2qn3nko6rnbf67due6abzxfdfu3uhcmocvdugng3k7xef7zc3x.py b/SpecForge-ext/cache/compiled_kernels/n2/cn2qn3nko6rnbf67due6abzxfdfu3uhcmocvdugng3k7xef7zc3x.py new file mode 100644 index 0000000000000000000000000000000000000000..7f056b16ac6aefd41dacae4d374995799b3a8758 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n2/cn2qn3nko6rnbf67due6abzxfdfu3uhcmocvdugng3k7xef7zc3x.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': 
DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/n2/cn2tg45k3x55ywffuxfwgqxig5mfgdvup54b54q7qi4axpjaewim.py b/SpecForge-ext/cache/compiled_kernels/n2/cn2tg45k3x55ywffuxfwgqxig5mfgdvup54b54q7qi4axpjaewim.py new file mode 100644 index 0000000000000000000000000000000000000000..a9afdf05425f194edac2aeaf0b15cd33ebb2b663 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n2/cn2tg45k3x55ywffuxfwgqxig5mfgdvup54b54q7qi4axpjaewim.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/n2/d2756d049ac90e7a818163a886bf3aceb280ae08dbb32a412d0a5e8a4f56b3df.best_config b/SpecForge-ext/cache/compiled_kernels/n2/d2756d049ac90e7a818163a886bf3aceb280ae08dbb32a412d0a5e8a4f56b3df.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b9c83cd70cc4f7d46eca037549afe001d843ad6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n2/d2756d049ac90e7a818163a886bf3aceb280ae08dbb32a412d0a5e8a4f56b3df.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 
1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/n4/3b108be27ed8197fd2c3db26a09f274abbb148fa3ed494c9797f4b5c7cd96355.best_config b/SpecForge-ext/cache/compiled_kernels/n4/3b108be27ed8197fd2c3db26a09f274abbb148fa3ed494c9797f4b5c7cd96355.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b3edbf99a0eba8d80382a60b6e8c1e8217d859f3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n4/3b108be27ed8197fd2c3db26a09f274abbb148fa3ed494c9797f4b5c7cd96355.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "L4G4T7LJLJ5UKCYTJO6FX7X7X5CAA5AHUH7H57L6AM4BNGXXAAVQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/n4/cn44zhfsqswlfcaweodtohqt7stdnvkic4sqxo4katmawv3urke5.py b/SpecForge-ext/cache/compiled_kernels/n4/cn44zhfsqswlfcaweodtohqt7stdnvkic4sqxo4katmawv3urke5.py new file mode 100644 index 0000000000000000000000000000000000000000..093c568cbc95bbc06f8d11392e6c751f0633dbbe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n4/cn44zhfsqswlfcaweodtohqt7stdnvkic4sqxo4katmawv3urke5.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 
'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/n7/cn7nomckjhwazqou2vmx3hgrloslbe6x4cp3jyfyvmbrabn6aapk.py b/SpecForge-ext/cache/compiled_kernels/n7/cn7nomckjhwazqou2vmx3hgrloslbe6x4cp3jyfyvmbrabn6aapk.py new file mode 100644 index 0000000000000000000000000000000000000000..e5e1c43f9fdecdd975ee2e7365bf31505955f86e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/n7/cn7nomckjhwazqou2vmx3hgrloslbe6x4cp3jyfyvmbrabn6aapk.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math 
import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fp/cfpqdwmnrpuw6cpaelcqs7w2is6wrhnduqm3fuvr47rbdqaklmih.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 
%floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: 
/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jx/cjxtezx44vmqh622f3tpmaklof56br4eylt3nz4a46kavvz2gwqw.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:7" = PlaceHolder[target=sum_1] +# %full_default : Tensor 
"b8[2, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:7"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:7"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = 
call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, 
s37), s37, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 
Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), 
kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = 
async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if 
(tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/b3/cb3yxtybf744swmcpe2lvz7uxmfgl5a6kt4up2cmxf36y3ryayam.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:7" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:7" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 
127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, 
xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/np/cnp6x45zcthhoa6foyolgsvdqvn4rmfofom26u7qk7euk6rr6ib2.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 
127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:7" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:7" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor 
"i64[1][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:7, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:7, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=2] = 
call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + 
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, 
eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/si/csifv3wugwezbiapa7tfir7475ta4akdayy7mujclukw2elqooca.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:7" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = 
{memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, 
eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chl76ryeuxy3vqhn7zzfop2anxhx2ux6vnshomo7va3ohvp5db2o.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:7" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:7" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:7"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 
1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 
'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ek/cekwqdbzromt7lrid3o5qjwspeentagjtfojo3ws7jkiiimzulw7.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:7" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, 
((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:7"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + 
triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream7) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream7) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + 
triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream7) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream7) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream7) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (2, 
1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream7) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream7) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, 
q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream7) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, 
triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream7) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream7) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream7 = get_raw_stream(7) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream7) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) 
// 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream7) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream7 = get_raw_stream(7) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream7) + del buf26 + return (buf27, buf29, buf30, buf32, 
buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1904 + arg1_1 = 1904 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:7', dtype=torch.int64) + arg3_1 = 1904 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ns/cnslziyflxu2476dyruev3x6nkek5zrjbz7kpgftr4piqbj2wwj2.py b/SpecForge-ext/cache/compiled_kernels/ns/cnslziyflxu2476dyruev3x6nkek5zrjbz7kpgftr4piqbj2wwj2.py new file mode 100644 index 0000000000000000000000000000000000000000..bff61548873f72d35b131611b4b2a522c50bb39b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ns/cnslziyflxu2476dyruev3x6nkek5zrjbz7kpgftr4piqbj2wwj2.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = 
torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ef/cefhdnts4a3zdc5ez2wyobzit3i2jmgbi72deqkx342da5va3ux2.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, 
BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vj/cvjzly27bwmukx4ax55pzamadufonbvzwq44ofqo7zyxiclgqpht.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, 
s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:2" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[2, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:2" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:2" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[2, 1, s100][s100, s100, 1]cuda:2" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[2, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:2" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[2, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, 
%primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], 
(15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : 
tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. 
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 
+ s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (2, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + ps0 = 32*s37 + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 64*s37 + stream2 = get_raw_stream(2) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream2) + del getitem + buf3 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + 
triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 2, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1856 + primals_11 = 1856 + primals_15 = 1856 + primals_7 = 15 + primals_8 = 15 + primals_12 = 15 + primals_16 = 15 + primals_18 = 15 + primals_19 = 15 + primals_21 = 15 + primals_24 = 15 + primals_23 = 15 + primals_26 = 15 + primals_29 = 15 + primals_28 = 15 + primals_2 = rand_strided((2, 32, 1856, 128), (7602176, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 8, 1856, 128), (1900544, 237568, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_6 = rand_strided((2, 8, 1856, 128), (1900544, 237568, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_9 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_17 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:2', dtype=torch.int32) + 
primals_20 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:2', dtype=torch.int32) + primals_22 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:2', dtype=torch.int32) + primals_25 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:2', dtype=torch.int32) + primals_27 = rand_strided((2, 1, 15), (15, 15, 1), device='cuda:2', dtype=torch.int32) + primals_30 = rand_strided((2, 1, 15, 15), (225, 225, 15, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((2, 32, 1856, 128), (7602176, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 1856), (59392, 1856, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 1856, 128), (7602176, 237568, 128, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/o2/2005a08e3f28c2ba6bce4c95dcad972f0b1bfb9396d1edbc64844649668d1190.best_config b/SpecForge-ext/cache/compiled_kernels/o2/2005a08e3f28c2ba6bce4c95dcad972f0b1bfb9396d1edbc64844649668d1190.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/o2/2005a08e3f28c2ba6bce4c95dcad972f0b1bfb9396d1edbc64844649668d1190.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": 
"3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/o2/co2y6jknd5elxfam4xxcqiwymrtcpe4mjfkclsyxgznl443h7uak.py b/SpecForge-ext/cache/compiled_kernels/o2/co2y6jknd5elxfam4xxcqiwymrtcpe4mjfkclsyxgznl443h7uak.py new file mode 100644 index 0000000000000000000000000000000000000000..52c4998afd375b111f6bcd088aa25201f4c244bb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/o2/co2y6jknd5elxfam4xxcqiwymrtcpe4mjfkclsyxgznl443h7uak.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + 
tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/o7/53f081c4bfe2ed8a80488981a5c3fef3b1297ea3a48b766653d0a1282370c07c.best_config b/SpecForge-ext/cache/compiled_kernels/o7/53f081c4bfe2ed8a80488981a5c3fef3b1297ea3a48b766653d0a1282370c07c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c2d9b36c5180887fa413aa1eb230c04dc216dd00 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/o7/53f081c4bfe2ed8a80488981a5c3fef3b1297ea3a48b766653d0a1282370c07c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/o7/co726ngn6j5cflfgsucsgervch74qdyubyfq42i7lqykher2plou.py b/SpecForge-ext/cache/compiled_kernels/o7/co726ngn6j5cflfgsucsgervch74qdyubyfq42i7lqykher2plou.py new file mode 100644 index 0000000000000000000000000000000000000000..4578b9d3bdb6f3685ef297450b50b4101994bea2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/o7/co726ngn6j5cflfgsucsgervch74qdyubyfq42i7lqykher2plou.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + 
filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = 
tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ob/cob65ptxwcswkyjowvaxmwnu4cpoiijoxwce6eyz2ndtpqxwqxm5.py b/SpecForge-ext/cache/compiled_kernels/ob/cob65ptxwcswkyjowvaxmwnu4cpoiijoxwce6eyz2ndtpqxwqxm5.py new file mode 100644 index 0000000000000000000000000000000000000000..70981e7dec4cf43aabdfd2be673126e3da3f6339 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ob/cob65ptxwcswkyjowvaxmwnu4cpoiijoxwce6eyz2ndtpqxwqxm5.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/og/474e6ee76c08daf6313ea5d043b7b7c3cc575319b509bae02c408f0d100938f4.best_config b/SpecForge-ext/cache/compiled_kernels/og/474e6ee76c08daf6313ea5d043b7b7c3cc575319b509bae02c408f0d100938f4.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c1c51c5048e176f0cf0b0d2646bd98c4186a3cba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/og/474e6ee76c08daf6313ea5d043b7b7c3cc575319b509bae02c408f0d100938f4.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "EGDJYO36DUYGK3UQBUH6S7RMVKF77GGHWVMFFZR5R4TDMIZ4YVJA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/og/cogol55cthk4zevsy3dlqiyzipefv735ge2wtaddvk436qht5nox.py b/SpecForge-ext/cache/compiled_kernels/og/cogol55cthk4zevsy3dlqiyzipefv735ge2wtaddvk436qht5nox.py new file mode 100644 index 0000000000000000000000000000000000000000..f86429bbada3c8795b1ac60e7d732b43a2d3bbde --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/og/cogol55cthk4zevsy3dlqiyzipefv735ge2wtaddvk436qht5nox.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, 
regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * 
tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/oj/coj6s3yvbdrqlv35il7rc5tnynrhnq5y53jqgehvkxnth3ocn534.py b/SpecForge-ext/cache/compiled_kernels/oj/coj6s3yvbdrqlv35il7rc5tnynrhnq5y53jqgehvkxnth3ocn534.py new file mode 100644 index 0000000000000000000000000000000000000000..79ac2fa1a68db2d6587f21cb688e59ad65628f2b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/oj/coj6s3yvbdrqlv35il7rc5tnynrhnq5y53jqgehvkxnth3ocn534.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 
16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : 
tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + 
stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/oj/cojsahfovcfxdl6jiqd3rnl6dcgz42rhl57v3a3hw2ovvedibuhs.py b/SpecForge-ext/cache/compiled_kernels/oj/cojsahfovcfxdl6jiqd3rnl6dcgz42rhl57v3a3hw2ovvedibuhs.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4e0a76482ef08899847e964362bbfcf7cc17ee --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/oj/cojsahfovcfxdl6jiqd3rnl6dcgz42rhl57v3a3hw2ovvedibuhs.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from 
torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ge/cgeiqrn6fc2445boi46pfasru3dymyjiw2xhga6ztucscbgv3gtp.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:2" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:2" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:2" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# 
%prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1310720000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = 
tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf2 = empty_strided_cuda((2, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream2 = get_raw_stream(2) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 4096, 32000, stream=stream2) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + 
compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/om/comip353oiewhfdkwlwkto6lm73uvpt5uezpl73dfzbfd6jxshzl.py b/SpecForge-ext/cache/compiled_kernels/om/comip353oiewhfdkwlwkto6lm73uvpt5uezpl73dfzbfd6jxshzl.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ca1b2fba5d1861cd2fa51345f7d1a0a2c299d6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/om/comip353oiewhfdkwlwkto6lm73uvpt5uezpl73dfzbfd6jxshzl.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() 
+empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/e6/ce6g3e5xikzaf3a5wmxill5os7magq3p3hzz7uw37za4jjui6tk6.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:0" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[8][1]cuda:0" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from 
torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': 
"'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in 
GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, 
primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 8, 32, stream=stream0) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call 
+recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:0', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:0', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:0', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/om/commwbmpskkyeb2il5plnuk3v52gjxr4cz37htdlnrshyczvy5xx.py b/SpecForge-ext/cache/compiled_kernels/om/commwbmpskkyeb2il5plnuk3v52gjxr4cz37htdlnrshyczvy5xx.py new file mode 100644 index 0000000000000000000000000000000000000000..8716b7a1245ef158cc6ffad420a12fdb857a3dba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/om/commwbmpskkyeb2il5plnuk3v52gjxr4cz37htdlnrshyczvy5xx.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': 
False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) 
!= 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/om/comum6k4pw6zmh6ju4rycwr6etxjcb3abs5wdl44clwcpipkgtbx.py b/SpecForge-ext/cache/compiled_kernels/om/comum6k4pw6zmh6ju4rycwr6etxjcb3abs5wdl44clwcpipkgtbx.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbd4ba5a78159b5796bd6306b4deb85216f89c0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/om/comum6k4pw6zmh6ju4rycwr6etxjcb3abs5wdl44clwcpipkgtbx.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': 
'*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 
0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always 
stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. 
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = 
(SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, 
sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False 
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/oo/1330f2612a3bfcef6da9324d0d3e8d960babf6c97bc942061654f27cf1e1d388.best_config b/SpecForge-ext/cache/compiled_kernels/oo/1330f2612a3bfcef6da9324d0d3e8d960babf6c97bc942061654f27cf1e1d388.best_config new file mode 100644 index 0000000000000000000000000000000000000000..0102fea510b9bf77ab661e714dfc816c066dc0d8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/oo/1330f2612a3bfcef6da9324d0d3e8d960babf6c97bc942061654f27cf1e1d388.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "IK5RT3JGLTF5PMMUH32NIWB2GXNU6R6CGIZSCRHU3I65YM226KDA"} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/oo/coo6yp6pxkzn6fs2k6qo3o6btpvpbycwufcwhpka6lqffc6e2vth.py b/SpecForge-ext/cache/compiled_kernels/oo/coo6yp6pxkzn6fs2k6qo3o6btpvpbycwufcwhpka6lqffc6e2vth.py new file mode 100644 index 0000000000000000000000000000000000000000..4ade2b9e622af3f8dc1ea1878779bbf2901bd0aa --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/oo/coo6yp6pxkzn6fs2k6qo3o6btpvpbycwufcwhpka6lqffc6e2vth.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/or/corueie7pxhqkyfb3s5e5zahkjku7uoaoyv6sseiyavb3i6dibx6.py b/SpecForge-ext/cache/compiled_kernels/or/corueie7pxhqkyfb3s5e5zahkjku7uoaoyv6sseiyavb3i6dibx6.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3c81e856784a4ca13925801bad158c6af86712 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/or/corueie7pxhqkyfb3s5e5zahkjku7uoaoyv6sseiyavb3i6dibx6.py @@ -0,0 +1,418 @@ +# AOT ID: ['15_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride 
= torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wl/cwlepqkeid7zo46auilgnirm5hqxuf7wqtbi3bhndddz2uyg7dbu.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s3, 32000][32000*s3, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, 
multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, 
_tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ne/cneup3jgvtc2kyazyreujmpq5tpdftpmosibbc6a6rhgokp3hfek.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg4_1 : Tensor "f32[8, s3, 32000][s71, 32000, 1]cuda:7" = PlaceHolder[target=arg4_1] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg4_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/d2/cd2itktqorx3l5od6vqzz3m73qqojgnnlgzl47ehhpxtuclki64i.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, 
aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, s3][s3, 1]cuda:7" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, s3][s3, 1]cuda:7" = PlaceHolder[target=argmax_1] +# %arg5_1 : Tensor "i64[8, s3, 1][s3, 1, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp4 = tl.load(in_ptr2 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = 
tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6h/c6hb7ee3npf6er65o5pppoq7uu3izd2oii7lz3cy4tsiysyvjtwd.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg7_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:7" = PlaceHolder[target=arg7_1] +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': 
True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/6s/c6sofobf6aibktijbyii34pvwp2u36pgbnmz5t2jtcp2eu7hxsct.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq_2 +# mul => mul_7 +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:7" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:7" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:7" = 
PlaceHolder[target=sum_2] +# %eq_2 : Tensor "b8[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg5_1, -1), kwargs = {}) +# %mul_7 : Tensor "i64[8, s3][s3, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq_2, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_7,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg7_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 
1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def 
__init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args + args.clear() + s3 = arg0_1 + s71 = arg2_1 + s0 = arg3_1 + s14 = arg6_1 + assert_size_stride(arg1_1, (8, s3, 32000), (32000*s3, 32000, 1)) + assert_size_stride(arg4_1, (8, s3, 32000), (s71, 32000, 1)) + assert_size_stride(arg5_1, (8, s3, 1), (s3, 1, 1)) + assert_size_stride(arg7_1, (8, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + triton_red_fused_argmax_0_xnumel = 8*s3 + stream7 = get_raw_stream(7) + triton_red_fused_argmax_0.run(arg1_1, buf0, triton_red_fused_argmax_0_xnumel, 32000, stream=stream7) + del arg1_1 + buf1 = empty_strided_cuda((8, s3), (s3, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + triton_red_fused_argmax_1_xnumel = 8*s3 + stream7 = get_raw_stream(7) + triton_red_fused_argmax_1.run(arg4_1, buf1, s3, s71, triton_red_fused_argmax_1_xnumel, 32000, stream=stream7) + del arg4_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + triton_red_fused_eq_mul_squeeze_sum_2_r0_numel = 4*s3 + stream7 = get_raw_stream(7) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg5_1, buf3, s3, 2, triton_red_fused_eq_mul_squeeze_sum_2_r0_numel, stream=stream7) + del arg5_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + triton_red_fused_sum_3_r0_numel = 4*s14 + stream7 = 
get_raw_stream(7) + triton_red_fused_sum_3.run(arg7_1, buf5, s14, 2, triton_red_fused_sum_3_r0_numel, stream=stream7) + del arg7_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream7 = get_raw_stream(7) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream7) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2025 + arg1_1 = rand_strided((8, 2025, 32000), (64800000, 32000, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = 65024000 + arg3_1 = 32000 + arg4_1 = rand_strided((8, 2025, 32000), (65024000, 32000, 1), device='cuda:7', dtype=torch.float32) + arg5_1 = rand_strided((8, 2025, 1), (2025, 1, 1), device='cuda:7', dtype=torch.int64) + arg6_1 = 2025 + arg7_1 = rand_strided((8, 2025, 1), (2025, 1, 1), device='cuda:7', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/or/corydm2c3xxksaci4ntuve665fj5b23b7ktrninp6eo7v36mnlmw.py b/SpecForge-ext/cache/compiled_kernels/or/corydm2c3xxksaci4ntuve665fj5b23b7ktrninp6eo7v36mnlmw.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ba839ec252acbf0a69e88eef39dc273840a0cf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/or/corydm2c3xxksaci4ntuve665fj5b23b7ktrninp6eo7v36mnlmw.py @@ -0,0 +1,44 @@ 
+ +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + 
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ou/coud2o52t4gsvsfv4xou2jy2lheb3owvezjxdxtjzrjhw7aywbbm.py b/SpecForge-ext/cache/compiled_kernels/ou/coud2o52t4gsvsfv4xou2jy2lheb3owvezjxdxtjzrjhw7aywbbm.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f66a26904648b8b4181f6fbe009c79111c9d01 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ou/coud2o52t4gsvsfv4xou2jy2lheb3owvezjxdxtjzrjhw7aywbbm.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, 
warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = 
False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. 
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ou/coukwgrmgroilk7bnuxtvs5dntno4azugcazcisd5pndhztcj2tw.py b/SpecForge-ext/cache/compiled_kernels/ou/coukwgrmgroilk7bnuxtvs5dntno4azugcazcisd5pndhztcj2tw.py new file mode 100644 index 0000000000000000000000000000000000000000..99155e9ab78dee6be40cc7cea3eb406c82418c6a --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/ou/coukwgrmgroilk7bnuxtvs5dntno4azugcazcisd5pndhztcj2tw.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fc/cfclrfzjl3tubzdja6ambskjxk6u7eawljcyu54gp6fv5g4t6zls.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# 
%primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 
'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, 
arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, 
primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 8, 32, stream=stream6) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call 
+recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ow/cowf5gtk4q6ge3wg4p36ieckswovdl46kpqdrzr7ay6zzl6qcshc.py b/SpecForge-ext/cache/compiled_kernels/ow/cowf5gtk4q6ge3wg4p36ieckswovdl46kpqdrzr7ay6zzl6qcshc.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd632cf2f74489e9a7a46d99aa46b9734fe28c6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ow/cowf5gtk4q6ge3wg4p36ieckswovdl46kpqdrzr7ay6zzl6qcshc.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): 
[['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + 
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. 
+ # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
# Accumulates the dQ contribution of a single K/V tile into `dq`.
#
# Steps (all visible in the body):
#   1. load K^T through `kT_ptrs` and form qk = q @ K^T, scaled by SM_SCALE;
#   2. unless IS_FULL_BLOCKS is set, evaluate the inlined mask and push the
#      rejected scores to -inf;
#   3. p = exp2(scores * RCP_LN2 - lse);
#   4. load V^T through `vT_ptrs`, dp = do @ V^T, ds = p * (dp - Di[:, None]);
#   5. dq += ds @ K, with K recovered as tl.trans(kT).
#
# Arguments:
#   arg_*, out_ptr0      kernel-level pointers forwarded by the caller; not
#                        referenced in this body.
#   in_ptr16             pointer read at index `off_z` by the mask code.
#   dq                   running accumulator; returned after the update.
#   q, do, Di, lse       tensors already loaded by the caller (presumably the
#                        query tile, its output gradient, DELTA and logsumexp
#                        rows -- confirm against the caller).
#   kT_ptrs, vT_ptrs     pointer tiles used for the transposed K / V loads.
#   Q_LEN, KV_LEN        sequence lengths; used only in bounds expressions.
#   off_z                index into in_ptr16.  off_hq is unused here.
#   offs_m2, offs_n2     row / column positions covered by this tile.
#   offs_k, offs_v       head-dimension offsets passed to load_checked_2d.
#   stride_*, kv_indices, sparse_kv_num_blocks
#                        accepted for signature parity; unused in this body.
#   MATMUL_PRECISION     dtype `ds` is cast to before the final dot.
#   RCP_LN2              factor applied to the scores before exp2.
#   IS_FULL_BLOCKS       when true the mask evaluation is skipped entirely.
#
# Returns: the updated `dq`.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants. Several of them (for example
    # ROWS_GUARANTEED_SAFE, OUTPUT_MAX, GQA_SHARED_HEADS) are not referenced
    # in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    # IS_DIVISIBLE is True in this variant, so every `if not IS_DIVISIBLE`
    # branch below is inactive and get_bounded_indices is called with None.
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: the offset/length arguments are in reversed order since K is
    # transposed (head dim first, sequence second). The strides are None
    # because kT_ptrs is already a fully formed pointer tile.
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # pre_mod_scores is assigned but not read again in this function.
    pre_mod_scores = qk
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we are iterating
    # across the N dim here, M may read out of bounds for the program ids
    # spanning the Q_LEN boundary.
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # No score modification is applied: the scores pass through unchanged.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask computation. With L = the value loaded from
        # in_ptr16[off_z]:
        #   keep = (m >= n and n < L and m < L)
        #          or (n >= 2048 and (n mod 2048) < L and (n - m) mod 2048 == 0)
        # "mod" is a floored modulo: the tmp18..tmp24 and tmp30..tmp35
        # sequences add the divisor back when the remainder is non-zero and
        # its sign differs from the divisor's.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # Apply the mask for a partially masked block: rejected positions get
        # -inf so that exp2 below yields 0 for them.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: reversed argument order since V is transposed.
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The score gradient passes through unchanged.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # scatter_mask is computed but not read again in this function.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
# Accumulates the dK and dV contributions of a single Q tile.
#
# Steps (all visible in the body):
#   1. load Q^T through `qT_ptrs` and the logsumexp row from `LSE`
#      (-inf entries are replaced by 0.0);
#   2. qkT = k @ Q^T, scaled by SM_SCALE;
#   3. unless IS_FULL_BLOCKS is set, evaluate the inlined mask and push the
#      rejected scores to -inf;
#   4. pT = exp2(scores * RCP_LN2 - lse[None, :]);
#   5. load dO through `do_ptrs`; dv += pT @ dO;
#   6. load Di from `DELTA`; dpT = v @ dO^T; dsT = pT * (dpT - Di[None, :]);
#   7. dk += dsT @ Q, with Q recovered as tl.trans(qT).
#
# Arguments:
#   arg_*, out_ptr0      kernel-level pointers forwarded by the caller; not
#                        referenced in this body.
#   in_ptr16             pointer read at index `off_z` by the mask code.
#   dk, dv               running accumulators; both are returned.
#   qT_ptrs, do_ptrs     pointer tiles for the Q^T and dO loads.
#   k, v                 tensors already loaded by the caller.
#   DELTA, LSE           base pointers indexed with `offs_m1`.
#   Q_LEN, KV_LEN        sequence lengths; used only in bounds expressions.
#   off_z, off_hq        off_z indexes in_ptr16; off_hq is read only inside
#                        the inactive `if not WRITE_DQ` branch.
#   offs_n1, offs_m1     KV / Q positions covered by this tile.
#   offs_k, offs_v       head-dimension offsets passed to load_checked_2d.
#   stride_*, q_indices, sparse_q_num_blocks
#                        accepted for signature parity; unused in this body.
#   MATMUL_PRECISION     dtype pT / dsT are cast to before the dots.
#   RCP_LN2              factor applied to the scores before exp2.
#   IS_FULL_BLOCKS       when true the mask evaluation is skipped entirely.
#
# Returns: (dk, dv).
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants. Several of them (for example
    # ROWS_GUARANTEED_SAFE, OUTPUT_MAX, GQA_SHARED_HEADS) are not referenced
    # in this function.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    # IS_DIVISIBLE is True in this variant, so the unmasked tl.load branches
    # are taken and every `if not IS_DIVISIBLE` branch is inactive.
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: reversed argument order since Q is transposed. The strides are None
    # because qT_ptrs is already a fully formed pointer tile.
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf logsumexp entries with 0.0 before they are subtracted from
    # the scores below.
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but since we are iterating
    # across the M dim here, n may read out of bounds for the program ids
    # spanning the KV_LEN boundary.
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    # pre_mod_scores is assigned but not read again in this function; the
    # scores pass through the (empty) score modification unchanged.
    pre_mod_scores = qkT
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask computation. With L = the value loaded from
        # in_ptr16[off_z]:
        #   keep = (m >= n and n < L and m < L)
        #          or (n >= 2048 and (n mod 2048) < L and (n - m) mod 2048 == 0)
        # "mod" is a floored modulo: the tmp58..tmp64 and tmp70..tmp75
        # sequences add the divisor back when the remainder is non-zero and
        # its sign differs from the divisor's.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # Apply the mask for a partially masked block (this branch only runs
        # when IS_FULL_BLOCKS is false): rejected positions get -inf so that
        # exp2 below yields 0 for them.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    # dO is not transposed, so the arguments are in (M, head-dim) order here.
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The score gradient passes through unchanged.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is True above, so this branch is inactive; the names it
    # assigns are not read anywhere else in this function.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK.
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
# Loads a 2D tile, bounds-checking only the dimensions that need it.
#
# When both strides are given, `ptr` is treated as a base pointer and the tile
# of addresses is built from `offs_m` (rows) and `offs_n` (columns); otherwise
# `ptr` is used as passed. A dimension flagged IS_DIVISIBLE_* is loaded
# without a mask; any other dimension is masked against its *_LEN, and masked
# elements are returned as 0.0.
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Resolve the tile of addresses to read from.
    if stride_m is None or stride_n is None:
        tile_ptrs = ptr
    else:
        row_ptrs = ptr + offs_m[:, None] * stride_m
        tile_ptrs = row_ptrs + offs_n[None, :] * stride_n

    # Pick the masking variant; both flags are compile-time constants.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Neither dimension can run out of bounds.
            return tl.load(tile_ptrs)
        else:
            cols_in_bounds = offs_n[None, :] < N_LEN
            return tl.load(tile_ptrs, mask=cols_in_bounds, other=0.0)
    else:
        rows_in_bounds = offs_m[:, None] < M_LEN
        if IS_DIVISIBLE_N:
            return tl.load(tile_ptrs, mask=rows_in_bounds, other=0.0)
        else:
            cols_in_bounds = offs_n[None, :] < N_LEN
            return tl.load(tile_ptrs, mask=rows_in_bounds & cols_in_bounds, other=0.0)
'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': 
False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # 
LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key; it is written via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each key/value. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each key/value. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each key/value. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each key/value.
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * 
GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
# Walks the K/V tiles selected by `kv_indices` and folds each tile's
# contribution into `dq` by calling bwd_dq_block_mn once per tile.
#
# K, V are base pointers; the transposed pointer tiles are built here from
# `offs_n2` and the head-dim offsets, then advanced after every tile by the
# amount returned from get_offset_for_next_block. The number of tiles visited
# is the smaller of (sparse_kv_num_blocks * SPARSE_KV_MULTIPLE) and
# max(cdiv(KV_LEN, BLOCK_N2), 1).
#
# Returns: the accumulated `dq`.
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Kernel configuration constants (not all are used in this function).
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # Number of BLOCK_N2 tiles that make up one sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = SPARSE_KV_BLOCK_SIZE // BLOCK_N2
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 2048
    KV_LEN = ks0

    # BLOCK_M2 has to be a whole multiple of BLOCK_N2 for this loop to work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Transposed pointer tiles: head dim along rows, KV positions along columns.
    kv_cols = offs_n2[None, :]
    kT_ptrs = K + kv_cols * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + kv_cols * stride_vn + offs_v[:, None] * stride_vd

    # Never visit more tiles than the sequence holds (at least one).
    tiles_from_mask = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
    tiles_in_sequence = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    num_tiles = tl.minimum(tiles_from_mask, tiles_in_sequence)

    for tile_idx in range(0, num_tiles):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
            dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in KV positions) from this tile to the next one to visit.
        advance = get_offset_for_next_block(
            tile_idx, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        # Move the position vector and both pointer tiles forward together.
        offs_n2 = offs_n2 + advance
        kT_ptrs = kT_ptrs + advance * stride_kn
        vT_ptrs = vT_ptrs + advance * stride_vn

    return dq
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/pe/bd73cc33af104115931a81ae68092b63c8ffd51a8b0c1f156a4dbb75e847d1a7.best_config b/SpecForge-ext/cache/compiled_kernels/pe/bd73cc33af104115931a81ae68092b63c8ffd51a8b0c1f156a4dbb75e847d1a7.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2606c4b85dba872f17d98e48096cb7db293c5e54 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pe/bd73cc33af104115931a81ae68092b63c8ffd51a8b0c1f156a4dbb75e847d1a7.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 17, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/pe/cpecz443wnneogc65oicauauoytwy7k6ryeyv24laczmux6pdi2b.py 
b/SpecForge-ext/cache/compiled_kernels/pe/cpecz443wnneogc65oicauauoytwy7k6ryeyv24laczmux6pdi2b.py new file mode 100644 index 0000000000000000000000000000000000000000..a41de8f36f93305f631dc76415f293f9d8b02f8c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pe/cpecz443wnneogc65oicauauoytwy7k6ryeyv24laczmux6pdi2b.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 
16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/pg/cpgdxbzkoe5fdebg4uspni2vq7hrekjkduwlnjdqhflv6bogduul.py b/SpecForge-ext/cache/compiled_kernels/pg/cpgdxbzkoe5fdebg4uspni2vq7hrekjkduwlnjdqhflv6bogduul.py new file mode 100644 index 0000000000000000000000000000000000000000..488014c4033f47d37a77e4f44cdb8929e5b5b7fe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pg/cpgdxbzkoe5fdebg4uspni2vq7hrekjkduwlnjdqhflv6bogduul.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from 
torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 
4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. 
+ # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/pw/cpwpu2cmad2immvjopaovup6vhhl7hcq6byi5ndp5mujt3vksqp7.py b/SpecForge-ext/cache/compiled_kernels/pw/cpwpu2cmad2immvjopaovup6vhhl7hcq6byi5ndp5mujt3vksqp7.py new file mode 100644 index 0000000000000000000000000000000000000000..373bea01771fa0341e8cbf3d8785a3950893b3c4 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/pw/cpwpu2cmad2immvjopaovup6vhhl7hcq6byi5ndp5mujt3vksqp7.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/od/codfpjb7dz7enkazdvqf5ewtodhwkmtmkpna4l26z45yma2zw2xk.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# 
%primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:5" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:5" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': 
'*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 
'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. 
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + 
tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif 
IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, 
QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up 
here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, 
primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, 
primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream5) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1569 + primals_2 = rand_strided((2, 32, 1569, 128), (6426624, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = 1569 + primals_4 = rand_strided((2, 8, 1569, 128), (1606656, 200832, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 1569 + primals_6 = rand_strided((2, 8, 1569, 128), (1606656, 200832, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = 13 + primals_8 = 13 + primals_9 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_10 = 1569 + primals_11 = 1569 + primals_12 = 13 + primals_13 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_15 = 1569 + primals_16 = 13 + primals_17 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_18 = 13 + primals_19 = 13 + primals_20 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_21 = 13 + primals_22 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + primals_23 = 13 + primals_24 = 13 + primals_25 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + primals_26 = 13 + primals_27 = rand_strided((2, 1, 13), (13, 13, 1), device='cuda:5', dtype=torch.int32) + 
primals_28 = 13 + primals_29 = 13 + primals_30 = rand_strided((2, 1, 13, 13), (169, 169, 13, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/pz/55d99bed8ef60879075dca9a24c03212b4ba7d3cb3730f76207fdd4c3676d39f.best_config b/SpecForge-ext/cache/compiled_kernels/pz/55d99bed8ef60879075dca9a24c03212b4ba7d3cb3730f76207fdd4c3676d39f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..37707241555f35a01f7e4a693e0cda27ae37aab0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pz/55d99bed8ef60879075dca9a24c03212b4ba7d3cb3730f76207fdd4c3676d39f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 22, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/pz/cpz4qtb54zohukrmcagdai5zuh3utgu3eax4ghwvhtdsekz2cclm.py b/SpecForge-ext/cache/compiled_kernels/pz/cpz4qtb54zohukrmcagdai5zuh3utgu3eax4ghwvhtdsekz2cclm.py new file mode 100644 index 0000000000000000000000000000000000000000..11ec8e4274e12bf9fc5d1fb6592a999855dbb683 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pz/cpz4qtb54zohukrmcagdai5zuh3utgu3eax4ghwvhtdsekz2cclm.py @@ -0,0 +1,62 @@ + +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, 
xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git 
a/SpecForge-ext/cache/compiled_kernels/pz/cpzw7g6yjflpctcqkzf5osq7m5acrctaysa6th3ox3deinxluypc.py b/SpecForge-ext/cache/compiled_kernels/pz/cpzw7g6yjflpctcqkzf5osq7m5acrctaysa6th3ox3deinxluypc.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bbb6bc6403a9df114662385c1bb8a650e12795 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/pz/cpzw7g6yjflpctcqkzf5osq7m5acrctaysa6th3ox3deinxluypc.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': 
False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= 
tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/r4/cr4vvf7u4h6ruefjt2n6mkf65qkivd25r7vnudgcyxoy5rribq72.py b/SpecForge-ext/cache/compiled_kernels/r4/cr4vvf7u4h6ruefjt2n6mkf65qkivd25r7vnudgcyxoy5rribq72.py new file mode 100644 index 0000000000000000000000000000000000000000..d62f83f433ceb1eea3e3ed8dcf0e5edca6517e4e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/r4/cr4vvf7u4h6ruefjt2n6mkf65qkivd25r7vnudgcyxoy5rribq72.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import 
_cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2o/c2o3rt6yqwbdpzcj2uw7zb7tho7rvfo6bovdzzruianbb5qt4dsy.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:0" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:0" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = 
call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1310720000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', 
other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf2 = empty_strided_cuda((2, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream0 = get_raw_stream(0) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 4096, 32000, stream=stream0) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:0', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ra/crazrfho4yh2243lvlav5lo2lup4tmy2xzfys3yoapld26l4if6l.py b/SpecForge-ext/cache/compiled_kernels/ra/crazrfho4yh2243lvlav5lo2lup4tmy2xzfys3yoapld26l4if6l.py new file mode 100644 index 
0000000000000000000000000000000000000000..fda7c3365474a8f7e648b6a369dd2728454281fe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ra/crazrfho4yh2243lvlav5lo2lup4tmy2xzfys3yoapld26l4if6l.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, 
R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/rl/crljmltx2ijltda364wepx5plo3nzy5cath6ikqwxj7wafhhskhq.py b/SpecForge-ext/cache/compiled_kernels/rl/crljmltx2ijltda364wepx5plo3nzy5cath6ikqwxj7wafhhskhq.py new file mode 100644 index 0000000000000000000000000000000000000000..54e67cd700a8aa96afcc6df0bac991888727a4ad --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/rl/crljmltx2ijltda364wepx5plo3nzy5cath6ikqwxj7wafhhskhq.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def 
triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = 
tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ry/cry6xskmafhujpmrarhzpcirnzkn5lbpatub2iu2sgelsuey6v2m.py b/SpecForge-ext/cache/compiled_kernels/ry/cry6xskmafhujpmrarhzpcirnzkn5lbpatub2iu2sgelsuey6v2m.py new file mode 100644 index 0000000000000000000000000000000000000000..14b7cf2c6f64f33975f2724338ad534b3c137497 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ry/cry6xskmafhujpmrarhzpcirnzkn5lbpatub2iu2sgelsuey6v2m.py @@ -0,0 +1,161 @@ +# AOT ID: ['11_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment 
+empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/v3/cv3ymxc5c4f2umbgpmwuvypoej63cmjx2dv22lsydctx55dk3ofc.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s67, 32000][32000*s67, 32000, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %getitem : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:7" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, s67, 1][s67, 1, 8*s67]cuda:7" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = 
{}) +# %div : Tensor "f32[8, s67, 32000][32000*s67, 32000, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 
'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def 
recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s67 = arg0_1 + assert_size_stride(arg1_1, (8, s67, 32000), (32000*s67, 32000, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf2 = empty_strided_cuda((8, s67, 32000), (32000*s67, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 8*s67 + stream7 = get_raw_stream(7) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream7) + del arg1_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2025 + arg1_1 = rand_strided((8, 2025, 32000), (64800000, 32000, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ry/cryateojd5rygqtyncy4ns32mbyzvtsyctvjwifg7wtvvo63redp.py b/SpecForge-ext/cache/compiled_kernels/ry/cryateojd5rygqtyncy4ns32mbyzvtsyctvjwifg7wtvvo63redp.py new file mode 100644 index 0000000000000000000000000000000000000000..f184e4c93890c05147fc18166dbd86b6cf570577 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/ry/cryateojd5rygqtyncy4ns32mbyzvtsyctvjwifg7wtvvo63redp.py @@ -0,0 +1,410 @@ +# AOT ID: ['7_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/vk/cvk2qr7hggrizog6osippdtnv4g54aa5mwpdaz7y5pik3awumasg.py +# Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:5" = 
PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# return %argmax +triton_red_fused_argmax_0 = async_compile.triton('triton_red_fused_argmax_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 1048576000}} +) +@triton.jit +def 
triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/n2/cn2tg45k3x55ywffuxfwgqxig5mfgdvup54b54q7qi4axpjaewim.py +# Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] +# Source node to ATen node mapping: +# argmax_1 => argmax_1 +# Graph fragment: +# %arg1_1 : Tensor "f32[8, 2048, 32000][65760000, 32000, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# return %argmax_1 +triton_red_fused_argmax_1 = async_compile.triton('triton_red_fused_argmax_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime 
import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 262144, 'r0_': 2097152000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, 
:] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/r2/cr2iww7jjspoka2bx5iw3mn43pwv5zkxeytjhhcu24gkrw4iirtm.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] +# Source node to ATen node mapping: +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# Graph fragment: +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:5" = PlaceHolder[target=argmax] +# %argmax_1 : Tensor "i64[8, 2048][2048, 1]cuda:5" = PlaceHolder[target=argmax_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:5" = PlaceHolder[target=arg2_1] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), 
kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# return %buf3 +triton_red_fused_eq_mul_squeeze_sum_2 = async_compile.triton('triton_red_fused_eq_mul_squeeze_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}} 
+) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp7, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/tz/ctzvxt2zqre7xja5duo4xuecem3vcakaorj5xml5vin4wzkb3dir.py +# Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] +# Source node to ATen node mapping: +# sum_2 => sum_2 +# Graph fragment: +# %arg3_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %sum_2 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# return %buf5 +triton_red_fused_sum_3 = async_compile.triton('triton_red_fused_sum_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from 
torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + 
r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xi/cxii5gghycgmlug7rn243ynjs3m6ekb6f3nhv5rsxxx2dnbhs2vs.py +# Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] +# Source node to ATen node mapping: +# clamp_min => clamp_min +# eq => eq +# mul => mul +# squeeze => squeeze +# sum_1 => sum_1 +# sum_2 => sum_2 +# truediv => div +# Graph fragment: +# %buf3 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=buf3] +# %buf5 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=buf5] +# %sum_1 : Tensor "i64[][]cuda:5" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[][]cuda:5" = PlaceHolder[target=sum_2] +# %eq : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Tensor](args = (%argmax, %argmax_1), kwargs = {}) +# %squeeze : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%arg2_1, -1), kwargs = {}) +# %mul : Tensor "i64[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%eq, %squeeze), kwargs = {}) +# %sum_1 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul,), kwargs = {}) +# %sum_2 : Tensor "i64[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%arg3_1,), kwargs = {}) +# %clamp_min : Tensor "f32[][]cuda:5"[num_users=1] = 
call_function[target=torch.ops.aten.clamp_min.default](args = (%sum_2, 1e-06), kwargs = {}) +# %div : Tensor "f32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_1, %clamp_min), kwargs = {}) +# return %sum_1,%sum_2,%div +triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4 = async_compile.triton('triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 1, 'r0_': 2}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 
'store_cubin': False, 'tiling_scores': {'r0_': 8}} +) +@triton.jit +def triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 2 + R0_BLOCK: tl.constexpr = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), None) + tmp4 = tl.load(in_ptr1 + (r0_0), None) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.sum(tmp1, 1)[:, None].to(tl.int64) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK]) + tmp7 = tl.sum(tmp5, 1)[:, None].to(tl.int64) + tmp8 = tmp3.to(tl.float32) + tmp9 = tmp7.to(tl.float32) + tmp10 = 1e-06 + tmp11 = triton_helpers.maximum(tmp9, tmp10) + tmp12 = (tmp8 / tmp11) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + assert_size_stride(arg1_1, (8, 2048, 32000), (65760000, 32000, 1)) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + assert_size_stride(arg3_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax], Original ATen: [aten.argmax] + stream5 = 
get_raw_stream(5) + triton_red_fused_argmax_0.run(arg0_1, buf0, 16384, 32000, stream=stream5) + del arg0_1 + buf1 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + # Topologically Sorted Source Nodes: [argmax_1], Original ATen: [aten.argmax] + stream5 = get_raw_stream(5) + triton_red_fused_argmax_1.run(arg1_1, buf1, 16384, 32000, stream=stream5) + del arg1_1 + buf3 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum] + stream5 = get_raw_stream(5) + triton_red_fused_eq_mul_squeeze_sum_2.run(buf0, buf1, arg2_1, buf3, 2, 8192, stream=stream5) + del arg2_1 + del buf0 + del buf1 + buf5 = empty_strided_cuda((2, ), (1, ), torch.int64) + # Topologically Sorted Source Nodes: [sum_2], Original ATen: [aten.sum] + stream5 = get_raw_stream(5) + triton_red_fused_sum_3.run(arg3_1, buf5, 2, 8192, stream=stream5) + del arg3_1 + buf7 = empty_strided_cuda((), (), torch.float32) + # Topologically Sorted Source Nodes: [eq, squeeze, mul, sum_1, sum_2, clamp_min, truediv], Original ATen: [aten.eq, aten.squeeze, aten.mul, aten.sum, aten.clamp_min, aten.div] + stream5 = get_raw_stream(5) + triton_per_fused_clamp_min_div_eq_mul_squeeze_sum_4.run(buf3, buf5, buf7, 1, 2, stream=stream5) + del buf3 + del buf5 + return (buf7, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:5', dtype=torch.bfloat16) + arg1_1 = rand_strided((8, 2048, 32000), (65760000, 32000, 1), device='cuda:5', dtype=torch.float32) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:5', dtype=torch.int64) + arg3_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:5', dtype=torch.int64) + fn 
= lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ry/crybdkaospobmiqlxc56pxoib5eehua75mh3od3mgjz4754h54wu.py b/SpecForge-ext/cache/compiled_kernels/ry/crybdkaospobmiqlxc56pxoib5eehua75mh3od3mgjz4754h54wu.py new file mode 100644 index 0000000000000000000000000000000000000000..118d534236af011ceb577133faf2740a72fce725 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ry/crybdkaospobmiqlxc56pxoib5eehua75mh3od3mgjz4754h54wu.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/sg/csg2r6gcuw5453tnkx7v65zysasesetlrx733ekbslnhgjntjrkm.py b/SpecForge-ext/cache/compiled_kernels/sg/csg2r6gcuw5453tnkx7v65zysasesetlrx733ekbslnhgjntjrkm.py new file mode 100644 index 0000000000000000000000000000000000000000..3d67d5b86bc10aa9f539d67efda9616a2fe6d3e5 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/sg/csg2r6gcuw5453tnkx7v65zysasesetlrx733ekbslnhgjntjrkm.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : 
tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tmp3 * tmp5 + tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) + tmp9 = _tmp8 + tmp7 + _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8) + tmp8 = tl.sum(_tmp8, 1)[:, None] + tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last') + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp12 = tmp10 * tmp11 + tmp13 = tmp12.to(tl.float32) + tmp15 = tmp13 * tmp14 + tmp16 = -0.5 + tmp17 = tmp8 * tmp16 + tmp18 = tmp14 * tmp14 + tmp19 = tmp18 * tmp14 + tmp20 = tmp17 * tmp19 + tmp21 = ks0 + tmp22 = tmp21.to(tl.float32) + tmp23 = (tmp20 / tmp22) + tmp25 = tmp24.to(tl.float32) + tmp26 = 2.0 + tmp27 
= tmp25 * tmp26 + tmp28 = tmp23 * tmp27 + tmp29 = tmp15 + tmp28 + tmp30 = tmp29.to(tl.float32) + tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/sw/cswtxydzgkh7pwlz42g4fmuptaeirf547qk2reulkxqvgklibkky.py b/SpecForge-ext/cache/compiled_kernels/sw/cswtxydzgkh7pwlz42g4fmuptaeirf547qk2reulkxqvgklibkky.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4e5d002be9b066b73ef373cacc0d01ac47e2e1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/sw/cswtxydzgkh7pwlz42g4fmuptaeirf547qk2reulkxqvgklibkky.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/sw/eb52fdbfd3cb014c353110a372bf31e8bcb295c12a02d7a92f55a1b12e20241f.best_config b/SpecForge-ext/cache/compiled_kernels/sw/eb52fdbfd3cb014c353110a372bf31e8bcb295c12a02d7a92f55a1b12e20241f.best_config new file mode 100644 index 
0000000000000000000000000000000000000000..73d39cec03a4913ffd38deb7ad038bf56b5cd33f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/sw/eb52fdbfd3cb014c353110a372bf31e8bcb295c12a02d7a92f55a1b12e20241f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/sy/csynkjvbhibjhhcz34qs745sszarr6ouqgojjblcyc4xy63ctige.py b/SpecForge-ext/cache/compiled_kernels/sy/csynkjvbhibjhhcz34qs745sszarr6ouqgojjblcyc4xy63ctige.py new file mode 100644 index 0000000000000000000000000000000000000000..d081cf7312c20726d3d2980b25b14e6bc4a206a3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/sy/csynkjvbhibjhhcz34qs745sszarr6ouqgojjblcyc4xy63ctige.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): 
[['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = 
tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/t3/ct3fq5xthzdec3x2zy2wlma26maa6ruf2v2tzswhoos342jsn7ev.py b/SpecForge-ext/cache/compiled_kernels/t3/ct3fq5xthzdec3x2zy2wlma26maa6ruf2v2tzswhoos342jsn7ev.py new file mode 100644 index 0000000000000000000000000000000000000000..658f89e01db5f2cf1d96178118b1b50697b08a6e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/t3/ct3fq5xthzdec3x2zy2wlma26maa6ruf2v2tzswhoos342jsn7ev.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': 
True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : 
tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. 
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # 
then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    K, V,  # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulates dQ for one query tile by walking the KV blocks listed in
    # `kv_indices`: each iteration hands the current transposed K/V pointers to
    # bwd_dq_block_mn, then advances them by the offset reported by
    # get_offset_for_next_block. Returns the updated `dq` accumulator.

    # --- kernel options baked in at code-generation time ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True

    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True

    # --- tile and sparse-block sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # The tiling below only works when BLOCK_M2 is a multiple of BLOCK_N2.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)

    # Number of BLOCK_N2-wide steps inside one sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    head_offs_qk = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    head_offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # K and V are addressed transposed: head dim along rows, KV position along columns.
    k_t_ptrs = K + offs_n2[None, :] * stride_kn + head_offs_qk[:, None] * stride_kd
    v_t_ptrs = V + offs_n2[None, :] * stride_vn + head_offs_v[:, None] * stride_vd

    # Iterate over the listed sparse blocks, but never more steps than are
    # needed to cover KV_LEN (and always allow at least one step in that cap).
    sparse_steps = sparse_kv_num_blocks * SPARSE_KV_MULTIPLE
    dense_steps = tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)
    num_steps = tl.minimum(sparse_steps, dense_steps)

    for step in range(0, num_steps):
        dq = bwd_dq_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dq, q, k_t_ptrs, v_t_ptrs, do, Di, lse, Q_LEN, KV_LEN,
            off_z, off_hq, offs_m2, offs_n2, head_offs_qk, head_offs_v,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in KV positions) from this tile to the next one to visit.
        advance = get_offset_for_next_block(
            step, kv_indices, sparse_kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
        )

        # Move the pointers and the matching KV position indices together.
        k_t_ptrs = k_t_ptrs + advance * stride_kn
        v_t_ptrs = v_t_ptrs + advance * stride_vn
        offs_n2 = offs_n2 + advance

    return dq
QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + 
tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
@triton.jit
def bwd_dkdv_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
    Q, DO, DELTA, LSE,  # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Accumulates dK and dV for one KV tile by walking the Q blocks listed in
    # `q_indices`: each iteration hands the current Q^T / dO pointers to
    # bwd_dkdv_block_mn, then advances them by the offset reported by
    # get_offset_for_next_block. Returns the updated (dk, dv) accumulators.

    # --- kernel options baked in at code-generation time ---
    PRESCALE_QK: tl.constexpr = False
    ROWS_GUARANTEED_SAFE: tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr = False
    WRITE_DQ: tl.constexpr = True
    OUTPUT_LOGSUMEXP: tl.constexpr = True
    OUTPUT_MAX: tl.constexpr = False
    FLOAT32_PRECISION: tl.constexpr = 'tf32'
    IS_DIVISIBLE: tl.constexpr = False
    SM_SCALE: tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS: tl.constexpr = 4
    HAS_FULL_BLOCKS: tl.constexpr = True

    # --- head dimensions ---
    QK_HEAD_DIM: tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED: tl.constexpr = 128
    V_HEAD_DIM: tl.constexpr = 128
    V_HEAD_DIM_ROUNDED: tl.constexpr = 128
    SAFE_HEAD_DIM: tl.constexpr = True

    # --- tile and sparse-block sizes ---
    BLOCK_M1: tl.constexpr = 64
    BLOCK_N1: tl.constexpr = 128
    BLOCK_M2: tl.constexpr = 128
    BLOCK_N2: tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE: tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE: tl.constexpr = 128
    INDEX_DTYPE: tl.constexpr = tl.int32

    # The tiling below only works when BLOCK_N1 is a multiple of BLOCK_M1.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

    # Number of BLOCK_M1-wide steps inside one sparse Q block.
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = ks0
    KV_LEN = ks1

    head_offs_qk = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    head_offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)

    # Q is addressed transposed (head dim along rows); dO is addressed row-major.
    q_t_ptrs = Q + offs_m1[None, :] * stride_qm + head_offs_qk[:, None] * stride_qd
    d_out_ptrs = DO + offs_m1[:, None] * stride_dom + head_offs_v[None, :] * stride_dod

    # Capping by the dense step count handles a very large SPARSE_BLOCK_SIZE
    # (i.e. running without a block mask).
    sparse_steps = sparse_q_num_blocks * SPARSE_Q_MULTIPLE
    dense_steps = tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)
    num_steps = tl.minimum(sparse_steps, dense_steps)

    for step in range(0, num_steps):
        dk, dv = bwd_dkdv_block_mn(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
            dk, dv, q_t_ptrs, k, v, d_out_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
            off_z, off_hq, offs_n1, offs_m1, head_offs_qk, head_offs_v,
            stride_qm, stride_qd, stride_dom, stride_dod,
            q_indices, sparse_q_num_blocks,
            MATMUL_PRECISION, RCP_LN2,
            IS_FULL_BLOCKS,
        )

        # Distance (in Q positions) from this tile to the next one to visit.
        advance = get_offset_for_next_block(
            step, q_indices, sparse_q_num_blocks,
            SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
        )

        # Move the pointers and the matching Q position indices together.
        q_t_ptrs = q_t_ptrs + advance * stride_qm
        d_out_ptrs = d_out_ptrs + advance * stride_dom
        offs_m1 = offs_m1 + advance

    return dk, dv
+ SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + 
tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far to advance along the blocked axis before the next loop
    # iteration. With contiguous blocks that is always one BLOCK.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK

    # Each sparse block is covered in SPARSE_BLOCK_MULTIPLE steps of size BLOCK.
    sparse_slot = loop_iter // SPARSE_BLOCK_MULTIPLE
    following_slot = sparse_slot + 1
    this_block = tl.load(col_indices + sparse_slot, eviction_policy="evict_last")
    # The look-ahead load is masked off when there is no further entry.
    following_block = tl.load(
        col_indices + following_slot,
        eviction_policy="evict_last",
        mask=following_slot < total_blocks,
    )

    # True on the last step inside the current sparse block.
    at_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Jump from the last tile of this sparse block to the first tile of the
    # next listed block: undo the (MULTIPLE - 1) in-block steps already taken.
    block_jump = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select: the jump at a block end, otherwise a plain BLOCK step.
    return block_jump * at_block_end + (1 - at_block_end) * BLOCK

@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` modulo `max_len`; passes them through untouched when no
    # bound is supplied.
    if max_len is not None:
        return indices % max_len
    else:
        return indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads through `block_ptr`, enabling zero-padded boundary checks only on
    # the dimensions whose flag is off: dim 0 when not IS_DIVISIBLE, dim 1 when
    # not SAFE_HEAD_DIM.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2D tile, masking (with 0.0 fill) each axis whose IS_DIVISIBLE_*
    # flag is off. When both strides are given, `ptr` is treated as a base and
    # the tile addresses are built here; otherwise `ptr` is used as passed.
    if stride_m is not None and stride_n is not None:
        row_term = offs_m[:, None] * stride_m
        col_term = offs_n[None, :] * stride_n
        ptr = ptr + row_term + col_term

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Nothing can fall out of range: unmasked load.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, 
eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp7, None) diff --git a/SpecForge-ext/cache/compiled_kernels/tc/ctcllk4e7sex44xmd2e5ern3qyza6xidctepyhutntdvvidiui4e.py b/SpecForge-ext/cache/compiled_kernels/tc/ctcllk4e7sex44xmd2e5ern3qyza6xidctepyhutntdvvidiui4e.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1b3b32f4d474a721ebd9b5f0780c21cebd40cc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tc/ctcllk4e7sex44xmd2e5ern3qyza6xidctepyhutntdvvidiui4e.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = 
torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/og/cogol55cthk4zevsy3dlqiyzipefv735ge2wtaddvk436qht5nox.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import 
triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, 
XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/yv/cyvkuxxhls3flmclcp72a4iqymktf2npz6j2o3p7iy2usfvd3izr.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : 
Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:6" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:6" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:6" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:6" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:6" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics 
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 
256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = 
arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K 
+= k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. 
These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + 
assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream6) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del 
primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:6', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:6', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:6', 
dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/th/05467070bad193b08424e57f7570b81029d321ce18b2b13d7a296a1b4c7ad48e.best_config b/SpecForge-ext/cache/compiled_kernels/th/05467070bad193b08424e57f7570b81029d321ce18b2b13d7a296a1b4c7ad48e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..990be040d913054ee650201b25cf2c95af882efd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/th/05467070bad193b08424e57f7570b81029d321ce18b2b13d7a296a1b4c7ad48e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "b70837e3723f218c7368cc2b49566dcd2bec3baf4c88b5e174a3f0822a6c86c0", "found_by_coordesc": false, "time_taken_ms": 142, "triton_cache_hash": "BZ2FPB5QIE7EHR6P7EPVPHR4HKS3YX3QQPIWQIT2R3EOJOAVWCGA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/th/cthv5zc2es46ngo2febwflavdqzw5qdaig35rrejlvqiistqzhbc.py b/SpecForge-ext/cache/compiled_kernels/th/cthv5zc2es46ngo2febwflavdqzw5qdaig35rrejlvqiistqzhbc.py new file mode 100644 index 0000000000000000000000000000000000000000..01a6b0c18af0600fcb2808388e83922639436908 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/th/cthv5zc2es46ngo2febwflavdqzw5qdaig35rrejlvqiistqzhbc.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, 
triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + 
r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) diff --git a/SpecForge-ext/cache/compiled_kernels/tk/ctkkhsglawnx4x5vzgiwopculwsuesf4qmzv2nk4aijxe5spueaw.py b/SpecForge-ext/cache/compiled_kernels/tk/ctkkhsglawnx4x5vzgiwopculwsuesf4qmzv2nk4aijxe5spueaw.py new file mode 100644 index 0000000000000000000000000000000000000000..7d93247fb5086496533acf0e17cecaa7b4c60413 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tk/ctkkhsglawnx4x5vzgiwopculwsuesf4qmzv2nk4aijxe5spueaw.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 
'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 
'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: 
Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. 
perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + 
stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/3TSRCAOXIABNSCY74AQ2GK4AIDHEKNRWVMFQVMQJIJ7C44WYAPYA/triton_red_fused_argmax_1.json b/SpecForge-ext/cache/compiled_kernels/triton/0/3TSRCAOXIABNSCY74AQ2GK4AIDHEKNRWVMFQVMQJIJ7C44WYAPYA/triton_red_fused_argmax_1.json new file mode 100644 index 0000000000000000000000000000000000000000..949d73949091e0aa1e18783929511149e2106e7b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/3TSRCAOXIABNSCY74AQ2GK4AIDHEKNRWVMFQVMQJIJ7C44WYAPYA/triton_red_fused_argmax_1.json @@ -0,0 +1 @@ +{"hash": "dce51101d74002d90b1fe021a32b8040ce453636ab0b0ab209427e2e72d803f0", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 16, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": 
false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 256, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_argmax_1"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..215739e267481e06dcac4f21bf8246ef3ff89f17 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ptx @@ -0,0 +1,412 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_zeros_0 // -- Begin function triton_red_fused_zeros_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_zeros_0 +.visible .entry triton_red_fused_zeros_0( + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_2, + .param .u32 triton_red_fused_zeros_0_param_3, + .param .u32 triton_red_fused_zeros_0_param_4, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_5, + 
.param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_6 +) +.reqntid 128 +{ + .reg .pred %p<4>; + .reg .b16 %rs<9>; + .reg .b32 %r<72>; + .reg .b64 %rd<11>; + .loc 1 18 0 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:18:0 + +// %bb.0: + ld.param.b64 %rd8, [triton_red_fused_zeros_0_param_0]; + ld.param.b64 %rd9, [triton_red_fused_zeros_0_param_1]; +$L__tmp0: + .loc 1 23 28 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:23:28 + mov.u32 %r10, %ctaid.x; + .loc 1 23 33 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:23:33 + shl.b32 %r11, %r10, 2; + ld.param.b64 %rd10, [triton_red_fused_zeros_0_param_2]; + .loc 1 24 44 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:24:44 + mov.u32 %r12, %tid.x; + and.b32 %r13, %r12, 96; + bfe.u32 %r14, %r12, 5, 2; + and.b32 %r15, %r12, 3; + .loc 1 24 23 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:24:23 + or.b32 %r16, %r14, %r11; + or.b32 %r17, %r11, %r15; + .loc 1 26 37 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:26:37 + shl.b32 %r18, %r12, 2; + and.b32 %r19, %r18, 124; + .loc 1 29 21 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:29:21 + bfe.s32 %r20, %r10, 29, 1; + shr.u32 %r21, %r20, 21; + add.s32 %r22, %r16, %r21; + shr.s32 %r23, %r22, 11; + .loc 1 28 19 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:28:19 + and.b32 %r24, %r22, 1046528; + sub.s32 %r25, %r16, %r24; + .loc 1 29 29 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:29:29 + shr.u32 %r26, %r23, 27; + add.s32 %r27, %r23, %r26; + and.b32 %r28, %r27, 33554400; + sub.s32 %r29, %r23, %r28; + .loc 1 30 19 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:30:19 + shr.u32 %r30, %r20, 16; + add.s32 %r31, %r16, %r30; + .loc 1 39 45 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:45 + shl.b32 %r32, %r29, 7; + .loc 1 39 55 // 
cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:55 + shl.b32 %r33, %r25, 12; + .loc 1 39 68 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:68 + shl.b32 %r34, %r31, 7; + and.b32 %r35, %r34, -8388608; + .loc 1 39 41 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:41 + or.b32 %r36, %r33, %r19; + .loc 1 39 50 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:50 + add.s32 %r37, %r36, %r35; + .loc 1 39 60 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:60 + add.s32 %r38, %r37, %r32; + .loc 1 39 34 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:34 + mad.wide.s32 %rd2, %r38, 2, %rd8; + .loc 1 39 73 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:73 + // begin inline asm + mov.u64 %rd3, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd3, 1.0; + // end inline asm + mov.b32 %r3, 0; + mov.pred %p1, -1; + // begin inline asm + mov.u32 %r1, %r3; + mov.u32 %r2, %r3; + @%p1 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r1, %r2 }, [ %rd2 + 0 ], %rd3; + // end inline asm + .loc 1 40 45 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:45 + shl.b32 %r39, %r16, 7; + .loc 1 40 41 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:41 + or.b32 %r40, %r39, %r19; + .loc 1 40 34 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:34 + mad.wide.s32 %rd5, %r40, 2, %rd9; + .loc 1 40 50 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:50 + // begin inline asm + mov.u64 %rd6, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd6, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r5, %r3; + mov.u32 %r6, %r3; + @%p1 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r5, %r6 }, [ %rd5 + 0 ], %rd6; + // end inline asm + .loc 1 39 127 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:127 + mov.b32 {%rs1, %rs2}, %r1; + cvt.f32.bf16 %r41, %rs1; + cvt.f32.bf16 %r42, %rs2; + .loc 1 40 104 // 
cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:104 + mov.b32 {%rs3, %rs4}, %r5; + cvt.f32.bf16 %r43, %rs3; + cvt.f32.bf16 %r44, %rs4; + .loc 1 43 23 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:43:23 + fma.rn.f32 %r45, %r42, %r44, 0f00000000; + fma.rn.f32 %r46, %r41, %r43, 0f00000000; + .loc 1 39 127 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:39:127 + mov.b32 {%rs5, %rs6}, %r2; + cvt.f32.bf16 %r47, %rs5; + cvt.f32.bf16 %r48, %rs6; + .loc 1 40 104 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:40:104 + mov.b32 {%rs7, %rs8}, %r6; + cvt.f32.bf16 %r49, %rs7; + cvt.f32.bf16 %r50, %rs8; + .loc 1 43 23 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:43:23 + fma.rn.f32 %r51, %r48, %r50, 0f00000000; + fma.rn.f32 %r52, %r47, %r49, 0f00000000; +$L__tmp1: + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r53, %r46, %r45; + add.f32 %r54, %r52, %r53; + add.f32 %r55, %r51, %r54; + .loc 2 291 36 // standard.py:291:36 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + shfl.sync.bfly.b32 %r56, %r55, 16, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r57, %r55, %r56; + .loc 2 291 36 // standard.py:291:36 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + shfl.sync.bfly.b32 %r58, %r57, 8, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r59, %r57, %r58; + .loc 2 291 36 // standard.py:291:36 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + shfl.sync.bfly.b32 %r60, %r59, 4, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r61, %r59, %r60; + .loc 2 291 36 // standard.py:291:36 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + shfl.sync.bfly.b32 %r62, %r61, 2, 31, 
-1; + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r63, %r61, %r62; + .loc 2 291 36 // standard.py:291:36 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + shfl.sync.bfly.b32 %r64, %r63, 1, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:25 ] + add.f32 %r65, %r63, %r64; +$L__tmp2: + .loc 1 45 28 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:45:28 + shr.u32 %r66, %r13, 3; + mov.b32 %r67, global_smem; + add.s32 %r68, %r67, %r66; + st.shared.b32 [%r68], %r65; + bar.sync 0; + shl.b32 %r69, %r15, 2; + add.s32 %r70, %r67, %r69; + ld.shared.b32 %r9, [%r70]; + .loc 1 49 25 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:49:25 + mad.wide.s32 %rd7, %r17, 4, %rd10; + .loc 1 49 36 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:49:36 + and.b32 %r71, %r12, 124; + setp.eq.b32 %p3, %r71, 0; + // begin inline asm + @%p3 st.global.b32 [ %rd7 + 0 ], { %r9 }; + // end inline asm + .loc 1 49 4 // cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py:49:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) 
+.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 209 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xca DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 117 +.b8 122 +.b8 99 +.b8 103 +.b8 116 +.b8 101 +.b8 98 +.b8 114 +.b8 121 +.b8 51 +.b8 120 +.b8 104 +.b8 120 +.b8 102 +.b8 119 +.b8 119 +.b8 110 +.b8 106 +.b8 112 +.b8 108 +.b8 102 +.b8 119 +.b8 115 +.b8 105 +.b8 111 +.b8 53 +.b8 112 +.b8 114 +.b8 120 +.b8 114 +.b8 116 +.b8 105 +.b8 55 +.b8 50 +.b8 110 +.b8 118 +.b8 108 +.b8 109 +.b8 103 +.b8 116 +.b8 104 +.b8 100 +.b8 51 +.b8 117 +.b8 109 +.b8 111 +.b8 105 +.b8 108 +.b8 116 +.b8 109 +.b8 98 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 
112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 117 +.b8 122 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1b DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa6:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbb:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.source new file mode 100644 index 0000000000000000000000000000000000000000..4838c0a109cc34de64baa7a0297954281410ec27 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.source @@ -0,0 +1,222 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":18:0) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc46 = loc(unknown) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc53 = loc("in_ptr0"(#loc)) +#loc54 = loc("in_ptr1"(#loc)) +#loc55 = loc("out_ptr1"(#loc)) +#loc56 = loc("xnumel"(#loc)) +#loc57 = 
loc("r0_numel"(#loc)) +#loc97 = loc("input"(#loc44)) +#loc98 = loc("a"(#loc49)) +#loc99 = loc("b"(#loc49)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 131072 : i32 loc(#loc58) + %r0_numel_1 = arith.constant 128 : i32 loc(#loc59) + %xoffset = tt.get_program_id x : i32 loc(#loc60) + %xoffset_2 = arith.constant 4 : i32 loc(#loc61) + %xoffset_3 = arith.constant 4 : i32 loc(#loc61) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc61) + %xindex = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc62) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> loc(#loc63) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<4x1xi32> loc(#loc64) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<4x1xi32> loc(#loc64) + %xmask = arith.constant true loc(#loc65) + %xmask_8 = arith.constant dense : tensor<4x128xi1> loc(#loc65) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc66) + %r0_base_9 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc67) + %x0 = arith.constant 2048 : i32 loc(#loc68) + %x0_10 = arith.constant 2048 : i32 loc(#loc68) + %x0_11 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc68) + %x0_12 = arith.remsi %xindex_7, %x0_11 : tensor<4x1xi32> loc(#loc68) + %x1 = arith.constant 2048 : i32 loc(#loc69) + %x1_13 = arith.constant 2048 : i32 loc(#loc69) + %x1_14 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc69) + %x1_15 = arith.divsi %xindex_7, %x1_14 : tensor<4x1xi32> loc(#loc69) + %x1_16 = arith.constant 32 : i32 
loc(#loc70) + %x1_17 = arith.constant 32 : i32 loc(#loc70) + %x1_18 = arith.constant dense<32> : tensor<4x1xi32> loc(#loc70) + %x1_19 = arith.remsi %x1_15, %x1_18 : tensor<4x1xi32> loc(#loc70) + %x2 = arith.constant 65536 : i32 loc(#loc71) + %x2_20 = arith.constant 65536 : i32 loc(#loc71) + %x2_21 = arith.constant dense<65536> : tensor<4x1xi32> loc(#loc71) + %x2_22 = arith.divsi %xindex_7, %x2_21 : tensor<4x1xi32> loc(#loc71) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc72) + %_tmp4_23 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc72) + %c0_i32 = arith.constant 0 : i32 loc(#loc16) + %c128_i32 = arith.constant 128 : i32 loc(#loc16) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc16) + %1 = arith.bitcast %r0_numel_1 : i32 to i32 loc(#loc16) + %2 = arith.bitcast %c128_i32 : i32 to i32 loc(#loc16) + %3 = ub.poison : i32 loc(#loc16) + %_tmp4_24 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp4_27 = %_tmp4_23) -> (tensor<4x128xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x128xi32> loc(#loc74) + %r0_index_28 = arith.addi %r0_index, %r0_base_9 : tensor<1x128xi32> loc(#loc74) + %r0_mask = arith.constant dense<128> : tensor<1x128xi32> loc(#loc75) + %r0_mask_29 = arith.cmpi slt, %r0_index_28, %r0_mask : tensor<1x128xi32> loc(#loc75) + %tmp0 = arith.constant 128 : i32 loc(#loc76) + %tmp0_30 = arith.constant 128 : i32 loc(#loc76) + %tmp0_31 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc76) + %tmp0_32 = arith.muli %tmp0_31, %x1_19 : tensor<4x1xi32> loc(#loc76) + %tmp0_33 = tt.broadcast %r0_index_28 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc77) + %tmp0_34 = tt.broadcast %tmp0_32 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc77) + %tmp0_35 = arith.addi %tmp0_33, %tmp0_34 : tensor<4x128xi32> loc(#loc77) + %tmp0_36 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_37 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_38 = arith.constant dense<4096> : tensor<4x1xi32> loc(#loc78) + %tmp0_39 = arith.muli %tmp0_38, 
%x0_12 : tensor<4x1xi32> loc(#loc78) + %tmp0_40 = tt.broadcast %tmp0_39 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc79) + %tmp0_41 = arith.addi %tmp0_35, %tmp0_40 : tensor<4x128xi32> loc(#loc79) + %tmp0_42 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_43 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_44 = arith.constant dense<8388608> : tensor<4x1xi32> loc(#loc80) + %tmp0_45 = arith.muli %tmp0_44, %x2_22 : tensor<4x1xi32> loc(#loc80) + %tmp0_46 = tt.broadcast %tmp0_45 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc81) + %tmp0_47 = arith.addi %tmp0_41, %tmp0_46 : tensor<4x128xi32> loc(#loc81) + %tmp0_48 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc82) + %tmp0_49 = tt.addptr %tmp0_48, %tmp0_47 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc82) + %tmp0_50 = arith.constant 0.000000e+00 : f32 loc(#loc83) + %tmp0_51 = tt.broadcast %r0_mask_29 : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc83) + %tmp0_52 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc83) + %tmp0_53 = arith.truncf %tmp0_52 : tensor<4x128xf32> to tensor<4x128xbf16> loc(#loc83) + %tmp0_54 = tt.load %tmp0_49, %tmp0_51, %tmp0_53 evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc83) + %tmp0_55 = arith.extf %tmp0_54 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc84) + %tmp1 = arith.constant 128 : i32 loc(#loc85) + %tmp1_56 = arith.constant 128 : i32 loc(#loc85) + %tmp1_57 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc85) + %tmp1_58 = arith.muli %tmp1_57, %xindex_7 : tensor<4x1xi32> loc(#loc85) + %tmp1_59 = tt.broadcast %r0_index_28 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc86) + %tmp1_60 = tt.broadcast %tmp1_58 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc86) + %tmp1_61 = arith.addi %tmp1_59, %tmp1_60 : tensor<4x128xi32> loc(#loc86) + %tmp1_62 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc87) + %tmp1_63 = tt.addptr %tmp1_62, %tmp1_61 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc87) + %tmp1_64 = 
arith.constant 0.000000e+00 : f32 loc(#loc88) + %tmp1_65 = tt.broadcast %r0_mask_29 : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc88) + %tmp1_66 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc88) + %tmp1_67 = arith.truncf %tmp1_66 : tensor<4x128xf32> to tensor<4x128xbf16> loc(#loc88) + %tmp1_68 = tt.load %tmp1_63, %tmp1_65, %tmp1_67 evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc88) + %tmp1_69 = arith.extf %tmp1_68 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc89) + %tmp2 = arith.mulf %tmp0_55, %tmp1_69 : tensor<4x128xf32> loc(#loc90) + %tmp5 = arith.addf %_tmp4_27, %tmp2 : tensor<4x128xf32> loc(#loc91) + %_tmp4_70 = tt.broadcast %r0_mask_29 : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc92) + %_tmp4_71 = arith.select %_tmp4_70, %tmp5, %_tmp4_27 : tensor<4x128xi1>, tensor<4x128xf32> loc(#loc92) + scf.yield %_tmp4_71 : tensor<4x128xf32> loc(#loc36) + } loc(#loc73) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S4_128S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_24) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc93) + %tmp4_25 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc94) + %tmp7 = arith.constant 0.000000e+00 : f32 loc(#loc95) + %tmp8 = arith.constant dense<0.000000e+00> : tensor<4x1xf32> loc(#loc96) + %tmp8_26 = arith.subf %tmp4_25, %tmp8 : tensor<4x1xf32> loc(#loc96) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc41) + %5 = tt.addptr %4, %xindex_7 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc41) + tt.store %5, %tmp8_26 : tensor<4x1x!tt.ptr> loc(#loc42) + tt.return loc(#loc43) + } loc(#loc) + tt.func private @"triton.language.standard.sum__fp32S4_128S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<4x128xf32> loc("input"(#loc44))) -> tensor<4xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call 
@triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc45) + tt.reduce.return %2 : f32 loc(#loc45) + }) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc45) + tt.return %0 : tensor<4xf32> loc(#loc47) + ^bb1: // no predecessors + %1 = ub.poison : tensor<4xf32> loc(#loc48) + tt.return %1 : tensor<4xf32> loc(#loc48) + } loc(#loc44) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc49)), %b: f32 loc("b"(#loc49))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc50) + tt.return %0 : f32 loc(#loc51) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc52) + tt.return %1 : f32 loc(#loc52) + } loc(#loc49) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":20:15) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":25:46) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":26:27) +#loc10 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":26:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":28:19) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:21) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":30:19) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":32:43) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":33:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":34:31) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":35:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:41) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:55) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:50) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:68) +#loc24 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:60) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:73) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:127) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:45) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:41) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:34) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:50) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:104) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":41:22) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":43:23) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":44:40) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":44:8) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:25) +#loc38 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:28) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":47:11) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":48:18) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:25) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:36) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:4) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc58 = loc("xnumel"(#loc1)) +#loc59 = loc("r0_numel"(#loc2)) +#loc60 = loc("xoffset"(#loc3)) +#loc61 = loc("xoffset"(#loc4)) +#loc62 = loc("xindex"(#loc5)) +#loc63 = loc("xindex"(#loc6)) +#loc64 = loc("xindex"(#loc7)) +#loc65 = loc("xmask"(#loc8)) +#loc66 = loc("r0_base"(#loc9)) +#loc67 = loc("r0_base"(#loc10)) +#loc68 = loc("x0"(#loc11)) +#loc69 = loc("x1"(#loc12)) +#loc70 = loc("x1"(#loc13)) +#loc71 = loc("x2"(#loc14)) +#loc72 = loc("_tmp4"(#loc15)) +#loc73 = loc("_tmp4"(#loc16)) +#loc74 = loc("r0_index"(#loc17)) +#loc75 = loc("r0_mask"(#loc18)) +#loc76 = 
loc("tmp0"(#loc19)) +#loc77 = loc("tmp0"(#loc20)) +#loc78 = loc("tmp0"(#loc21)) +#loc79 = loc("tmp0"(#loc22)) +#loc80 = loc("tmp0"(#loc23)) +#loc81 = loc("tmp0"(#loc24)) +#loc82 = loc("tmp0"(#loc25)) +#loc83 = loc("tmp0"(#loc26)) +#loc84 = loc("tmp0"(#loc27)) +#loc85 = loc("tmp1"(#loc28)) +#loc86 = loc("tmp1"(#loc29)) +#loc87 = loc("tmp1"(#loc30)) +#loc88 = loc("tmp1"(#loc31)) +#loc89 = loc("tmp1"(#loc32)) +#loc90 = loc("tmp2"(#loc33)) +#loc91 = loc("tmp5"(#loc34)) +#loc92 = loc("_tmp4"(#loc35)) +#loc93 = loc("tmp4"(#loc37)) +#loc94 = loc("tmp4"(#loc38)) +#loc95 = loc("tmp7"(#loc39)) +#loc96 = loc("tmp8"(#loc40)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..aa299359acc6f4aee5411619d3bfc279fee956a5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttgir @@ -0,0 +1,142 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":18:0) +#loc1 = loc(unknown) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:25) +#loc36 = loc("in_ptr0"(#loc)) +#loc37 = loc("in_ptr1"(#loc)) +#loc38 = loc("out_ptr1"(#loc)) +#loc39 = loc("xnumel"(#loc)) +#loc40 = loc("r0_numel"(#loc)) +#loc68 = loc("tmp4"(#loc30)) +#loc71 = loc(callsite(#loc1 at #loc68)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, 
ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<1x128xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<8388608> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<65536> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<32> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_5 = arith.constant dense<2048> : tensor<4x1xi32, #blocked> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x128xbf16, #blocked> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<4x128xf32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc41) + %xoffset_8 = arith.muli %xoffset, %c4_i32 : i32 loc(#loc42) + %xindex = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc43) + %xindex_9 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc43) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi32, #blocked> loc(#loc43) + %xindex_11 = tt.expand_dims %xindex_9 {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<4x1xi32, #blocked1> loc(#loc43) + %xindex_12 = tt.splat %xoffset_8 : i32 -> tensor<4x1xi32, #blocked> 
loc(#loc44) + %xindex_13 = tt.splat %xoffset_8 : i32 -> tensor<4x1xi32, #blocked1> loc(#loc44) + %xindex_14 = arith.addi %xindex_12, %xindex_10 : tensor<4x1xi32, #blocked> loc(#loc44) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<4x1xi32, #blocked1> loc(#loc44) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc45) + %r0_base_16 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc45) + %x0 = arith.remsi %xindex_14, %cst_5 : tensor<4x1xi32, #blocked> loc(#loc46) + %x1 = arith.divsi %xindex_14, %cst_5 : tensor<4x1xi32, #blocked> loc(#loc47) + %x1_17 = arith.remsi %x1, %cst_4 : tensor<4x1xi32, #blocked> loc(#loc48) + %x2 = arith.divsi %xindex_14, %cst_3 : tensor<4x1xi32, #blocked> loc(#loc49) + %r0_mask = arith.cmpi slt, %r0_base_16, %cst : tensor<1x128xi32, #blocked> loc(#loc50) + %tmp0 = arith.muli %x1_17, %cst_0 : tensor<4x1xi32, #blocked> loc(#loc51) + %tmp0_18 = tt.broadcast %r0_base_16 : tensor<1x128xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_19 = tt.broadcast %tmp0 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_20 = arith.addi %tmp0_18, %tmp0_19 : tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_21 = arith.muli %x0, %cst_1 : tensor<4x1xi32, #blocked> loc(#loc53) + %tmp0_22 = tt.broadcast %tmp0_21 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc54) + %tmp0_23 = arith.addi %tmp0_20, %tmp0_22 : tensor<4x128xi32, #blocked> loc(#loc54) + %tmp0_24 = arith.muli %x2, %cst_2 : tensor<4x1xi32, #blocked> loc(#loc55) + %tmp0_25 = tt.broadcast %tmp0_24 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc56) + %tmp0_26 = arith.addi %tmp0_23, %tmp0_25 : tensor<4x128xi32, #blocked> loc(#loc56) + %tmp0_27 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr, #blocked> loc(#loc57) + %tmp0_28 = tt.addptr %tmp0_27, 
%tmp0_26 : tensor<4x128x!tt.ptr, #blocked>, tensor<4x128xi32, #blocked> loc(#loc57) + %tmp0_29 = tt.broadcast %r0_mask : tensor<1x128xi1, #blocked> -> tensor<4x128xi1, #blocked> loc(#loc58) + %tmp0_30 = tt.load %tmp0_28, %tmp0_29, %cst_6 evictionPolicy = evict_first : tensor<4x128x!tt.ptr, #blocked> loc(#loc58) + %tmp0_31 = arith.extf %tmp0_30 : tensor<4x128xbf16, #blocked> to tensor<4x128xf32, #blocked> loc(#loc59) + %tmp1 = arith.muli %xindex_14, %cst_0 : tensor<4x1xi32, #blocked> loc(#loc60) + %tmp1_32 = tt.broadcast %tmp1 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc61) + %tmp1_33 = arith.addi %tmp0_18, %tmp1_32 : tensor<4x128xi32, #blocked> loc(#loc61) + %tmp1_34 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr, #blocked> loc(#loc62) + %tmp1_35 = tt.addptr %tmp1_34, %tmp1_33 : tensor<4x128x!tt.ptr, #blocked>, tensor<4x128xi32, #blocked> loc(#loc62) + %tmp1_36 = tt.load %tmp1_35, %tmp0_29, %cst_6 evictionPolicy = evict_first : tensor<4x128x!tt.ptr, #blocked> loc(#loc63) + %tmp1_37 = arith.extf %tmp1_36 : tensor<4x128xbf16, #blocked> to tensor<4x128xf32, #blocked> loc(#loc64) + %tmp2 = arith.mulf %tmp0_31, %tmp1_37 : tensor<4x128xf32, #blocked> loc(#loc65) + %tmp5 = arith.addf %tmp2, %cst_7 : tensor<4x128xf32, #blocked> loc(#loc66) + %_tmp4 = arith.select %tmp0_29, %tmp5, %cst_7 : tensor<4x128xi1, #blocked>, tensor<4x128xf32, #blocked> loc(#loc67) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_40: f32 loc(callsite(#loc1 at #loc68)), %tmp4_41: f32 loc(callsite(#loc1 at #loc68))): + %tmp4_42 = arith.addf %tmp4_40, %tmp4_41 : f32 loc(#loc72) + tt.reduce.return %tmp4_42 : f32 loc(#loc70) + }) : (tensor<4x128xf32, #blocked>) -> tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc70) + %tmp4_38 = ttg.convert_layout %tmp4 : tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc69) + %tmp4_39 = tt.expand_dims %tmp4_38 {axis = 1 : i32} : 
tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<4x1xf32, #blocked1> loc(#loc69) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr, #blocked1> loc(#loc33) + %1 = tt.addptr %0, %xindex_15 : tensor<4x1x!tt.ptr, #blocked1>, tensor<4x1xi32, #blocked1> loc(#loc33) + tt.store %1, %tmp4_39 : tensor<4x1x!tt.ptr, #blocked1> loc(#loc34) + tt.return loc(#loc35) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":26:37) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":28:19) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:29) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":30:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":35:29) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:45) +#loc13 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:41) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:55) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:50) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:68) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:60) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:34) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:73) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:127) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:45) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:41) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:34) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:50) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:104) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":41:22) +#loc27 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":43:23) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":44:40) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:28) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:25) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:36) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:4) +#loc41 = loc("xoffset"(#loc2)) +#loc42 = loc("xoffset"(#loc3)) +#loc43 = loc("xindex"(#loc4)) +#loc44 = loc("xindex"(#loc5)) +#loc45 = loc("r0_base"(#loc6)) +#loc46 = loc("x0"(#loc7)) +#loc47 = loc("x1"(#loc8)) +#loc48 = loc("x1"(#loc9)) +#loc49 = loc("x2"(#loc10)) +#loc50 = loc("r0_mask"(#loc11)) +#loc51 = loc("tmp0"(#loc12)) +#loc52 = loc("tmp0"(#loc13)) +#loc53 = loc("tmp0"(#loc14)) +#loc54 = loc("tmp0"(#loc15)) +#loc55 = loc("tmp0"(#loc16)) +#loc56 = loc("tmp0"(#loc17)) +#loc57 = loc("tmp0"(#loc18)) +#loc58 = loc("tmp0"(#loc19)) +#loc59 = loc("tmp0"(#loc20)) +#loc60 = loc("tmp1"(#loc21)) +#loc61 = loc("tmp1"(#loc22)) +#loc62 = loc("tmp1"(#loc23)) +#loc63 = loc("tmp1"(#loc24)) +#loc64 = loc("tmp1"(#loc25)) +#loc65 = loc("tmp2"(#loc26)) +#loc66 = loc("tmp5"(#loc27)) +#loc67 = loc("_tmp4"(#loc28)) +#loc69 = loc("tmp4"(#loc32)) +#loc70 = loc(callsite(#loc29 at #loc68)) +#loc72 = loc(callsite(#loc31 at #loc70)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..1920bc2ad7c733f2b632ffa97ad3b4634a16bf21 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/4ODNRDUZZQ6DC76SMZ2XG44NLENCYGAAAUTCAR2KZZILJTJU67XA/triton_red_fused_zeros_0.ttir @@ -0,0 +1,139 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":18:0) +#loc1 = loc(unknown) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:25) +#loc38 = loc("in_ptr0"(#loc)) +#loc39 = loc("in_ptr1"(#loc)) +#loc40 = loc("out_ptr1"(#loc)) +#loc41 = loc("xnumel"(#loc)) +#loc42 = loc("r0_numel"(#loc)) +#loc72 = loc("tmp4"(#loc32)) +#loc75 = loc(callsite(#loc1 at #loc72)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<4x128xbf16> loc(#loc1) + %cst_0 = arith.constant dense<8388608> : tensor<4x1xi32> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<4x1xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc1) + %x2 = arith.constant dense<65536> : tensor<4x1xi32> loc(#loc43) + %x1 = arith.constant 
dense<32> : tensor<4x1xi32> loc(#loc44) + %cst_5 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc45) + %xoffset_6 = arith.muli %xoffset, %c4_i32 : i32 loc(#loc46) + %xindex = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc47) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> loc(#loc48) + %xindex_8 = tt.splat %xoffset_6 : i32 -> tensor<4x1xi32> loc(#loc49) + %xindex_9 = arith.addi %xindex_8, %xindex_7 : tensor<4x1xi32> loc(#loc49) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc50) + %r0_base_10 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc51) + %x0 = arith.remsi %xindex_9, %cst_5 : tensor<4x1xi32> loc(#loc52) + %x1_11 = arith.divsi %xindex_9, %cst_5 : tensor<4x1xi32> loc(#loc53) + %x1_12 = arith.remsi %x1_11, %x1 : tensor<4x1xi32> loc(#loc44) + %x2_13 = arith.divsi %xindex_9, %x2 : tensor<4x1xi32> loc(#loc43) + %r0_mask = arith.cmpi slt, %r0_base_10, %cst_3 : tensor<1x128xi32> loc(#loc54) + %tmp0 = arith.muli %x1_12, %cst_2 : tensor<4x1xi32> loc(#loc55) + %tmp0_14 = tt.broadcast %r0_base_10 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc56) + %tmp0_15 = tt.broadcast %tmp0 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc56) + %tmp0_16 = arith.addi %tmp0_14, %tmp0_15 : tensor<4x128xi32> loc(#loc56) + %tmp0_17 = arith.muli %x0, %cst_1 : tensor<4x1xi32> loc(#loc57) + %tmp0_18 = tt.broadcast %tmp0_17 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc58) + %tmp0_19 = arith.addi %tmp0_16, %tmp0_18 : tensor<4x128xi32> loc(#loc58) + %tmp0_20 = arith.muli %x2_13, %cst_0 : tensor<4x1xi32> loc(#loc59) + %tmp0_21 = tt.broadcast %tmp0_20 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc60) + %tmp0_22 = arith.addi %tmp0_19, %tmp0_21 : tensor<4x128xi32> loc(#loc60) + %tmp0_23 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr> 
loc(#loc61) + %tmp0_24 = tt.addptr %tmp0_23, %tmp0_22 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc61) + %tmp0_25 = tt.broadcast %r0_mask : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc62) + %tmp0_26 = tt.load %tmp0_24, %tmp0_25, %cst evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc62) + %tmp0_27 = arith.extf %tmp0_26 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc63) + %tmp1 = arith.muli %xindex_9, %cst_2 : tensor<4x1xi32> loc(#loc64) + %tmp1_28 = tt.broadcast %tmp1 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc65) + %tmp1_29 = arith.addi %tmp0_14, %tmp1_28 : tensor<4x128xi32> loc(#loc65) + %tmp1_30 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc66) + %tmp1_31 = tt.addptr %tmp1_30, %tmp1_29 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc66) + %tmp1_32 = tt.load %tmp1_31, %tmp0_25, %cst evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc67) + %tmp1_33 = arith.extf %tmp1_32 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc68) + %tmp2 = arith.mulf %tmp0_27, %tmp1_33 : tensor<4x128xf32> loc(#loc69) + %tmp5 = arith.addf %tmp2, %cst_4 : tensor<4x128xf32> loc(#loc70) + %_tmp4 = arith.select %tmp0_25, %tmp5, %cst_4 : tensor<4x128xi1>, tensor<4x128xf32> loc(#loc71) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_35: f32 loc(callsite(#loc1 at #loc72)), %tmp4_36: f32 loc(callsite(#loc1 at #loc72))): + %tmp4_37 = arith.addf %tmp4_35, %tmp4_36 : f32 loc(#loc76) + tt.reduce.return %tmp4_37 : f32 loc(#loc74) + }) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc74) + %tmp4_34 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc73) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc35) + %1 = tt.addptr %0, %xindex_9 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc35) + tt.store %1, %tmp4_34 : tensor<4x1x!tt.ptr> loc(#loc36) + tt.return loc(#loc37) + } loc(#loc) +} loc(#loc) +#loc2 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":30:19) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:29) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":23:33) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:36) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:44) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":24:23) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":26:27) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":26:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":28:19) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":29:21) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":35:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:45) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:41) +#loc16 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:55) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:50) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:68) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:60) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:34) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:73) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":39:127) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:45) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:41) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:50) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":40:104) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":41:22) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":43:23) +#loc30 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":44:40) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":45:28) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:25) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:36) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/uz/cuzcgtebry3xhxfwwnjplfwsio5prxrti72nvlmgthd3umoiltmb.py":49:4) +#loc43 = loc("x2"(#loc2)) +#loc44 = loc("x1"(#loc3)) +#loc45 = loc("xoffset"(#loc4)) +#loc46 = loc("xoffset"(#loc5)) +#loc47 = loc("xindex"(#loc6)) +#loc48 = loc("xindex"(#loc7)) +#loc49 = loc("xindex"(#loc8)) +#loc50 = loc("r0_base"(#loc9)) +#loc51 = loc("r0_base"(#loc10)) +#loc52 = loc("x0"(#loc11)) +#loc53 = loc("x1"(#loc12)) +#loc54 = loc("r0_mask"(#loc13)) +#loc55 = loc("tmp0"(#loc14)) +#loc56 = loc("tmp0"(#loc15)) +#loc57 = loc("tmp0"(#loc16)) +#loc58 = loc("tmp0"(#loc17)) +#loc59 = loc("tmp0"(#loc18)) +#loc60 = loc("tmp0"(#loc19)) +#loc61 = loc("tmp0"(#loc20)) +#loc62 = loc("tmp0"(#loc21)) +#loc63 = loc("tmp0"(#loc22)) +#loc64 = loc("tmp1"(#loc23)) +#loc65 = loc("tmp1"(#loc24)) +#loc66 = loc("tmp1"(#loc25)) +#loc67 = loc("tmp1"(#loc26)) +#loc68 = loc("tmp1"(#loc27)) +#loc69 = loc("tmp2"(#loc28)) +#loc70 = loc("tmp5"(#loc29)) +#loc71 = loc("_tmp4"(#loc30)) +#loc73 = loc("tmp4"(#loc34)) +#loc74 = loc(callsite(#loc31 at #loc72)) +#loc76 = loc(callsite(#loc33 at #loc74)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json new file mode 100644 index 0000000000000000000000000000000000000000..67c4154d3b5fd4e5ed14cce2aca7a9906a3e97a1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json @@ -0,0 +1 @@ +{"child_paths": {"triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir", 
"triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin new file mode 
100644 index 0000000000000000000000000000000000000000..fcf4ba7483734066b3850e3798ce9ca3dbfab72d Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir new file mode 100644 index 0000000000000000000000000000000000000000..f2b4674f4817e34138224d08469ca06ae2c43496 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir @@ -0,0 +1,1110 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@assertFunc_1 = internal constant [8 x i8] c"unknown\00" +@assertFile_1 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py\00" +@assertMessage_1 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp49 < 17\00" +@assertFunc_0 = internal constant [8 x i8] c"unknown\00" +@assertFile_0 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py\00" +@assertMessage_0 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp40 
< 17\00" +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: noreturn +declare !dbg !5 void @__assertfail(ptr, ptr, i32, ptr, i64) local_unnamed_addr #0 + +define ptx_kernel void @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, i32 %7, i32 %8, ptr addrspace(1) readnone captures(none) %9, ptr addrspace(1) readnone captures(none) %10) local_unnamed_addr #1 !dbg !9 { + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !10 + %13 = shl i32 %12, 3, !dbg !11 + %14 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12 + %15 = and i32 %14, 56, !dbg !12 + %16 = lshr exact i32 %15, 3, !dbg !12 + %17 = and i32 %14, 7, !dbg !12 + %18 = or disjoint i32 %16, %13, !dbg !13 + %19 = icmp slt i32 %18, 32, !dbg !14 + %20 = shl nuw nsw i32 %17, 1, !dbg !15 + %21 = or disjoint i32 %20, 1, !dbg !15 + %22 = shl i32 %18, 4, !dbg !16 + %23 = or disjoint i32 %22, %20, !dbg !17 + %24 = sext i32 %23 to i64, !dbg !18 + %25 = getelementptr i64, ptr addrspace(1) %0, i64 %24, !dbg !18 + %26 = tail call { i64, i64 } asm sideeffect "mov.u64 $0, 0x0;\0A\09mov.u64 $1, 0x0;\0A\09@$3 ld.global.v2.b64 { $0, $1 }, [ $2 + 0 ];", "=l,=l,l,b"(ptr addrspace(1) %25, i1 %19) #5, !dbg !19 + %27 = extractvalue { i64, i64 } %26, 0, !dbg !19 + %28 = extractvalue { i64, i64 } %26, 1, !dbg !19 + %29 = add i64 %27, -1, !dbg !20 + %30 = icmp ult i64 %29, 16383, !dbg !20 + %31 = add i64 %28, -1, !dbg !20 + %32 = icmp ult i64 %31, 16383, !dbg !20 + %33 = zext i1 %30 to i32, !dbg !21 + %34 = zext i1 %32 to i32, !dbg !21 + %35 = and i32 %14, 1, !dbg !22 + %36 = lshr i32 %14, 1, !dbg !22 + %.lobit = and i32 %36, 1, !dbg !22 + %37 = lshr i32 %14, 2, !dbg !22 + %.lobit1 = and i32 %37, 1, !dbg !22 + %38 = xor i32 %35, 1, !dbg !26 + %39 = xor i32 
%.lobit, 1, !dbg !26 + %40 = xor i32 %.lobit1, 1, !dbg !26 + %41 = xor i1 %30, true, !dbg !27 + %42 = and i1 %32, %41, !dbg !27 + %43 = trunc i32 %14 to i1, !dbg !28 + %44 = xor i1 %42, %43, !dbg !28 + %45 = xor i32 %33, %34, !dbg !29 + %46 = select i1 %44, i32 %45, i32 0, !dbg !30 + %47 = xor i32 %46, %33, !dbg !31 + %48 = xor i32 %46, %34, !dbg !31 + %49 = xor i32 %21, %20, !dbg !32 + %50 = select i1 %44, i32 %49, i32 0, !dbg !33 + %51 = xor i32 %50, %20, !dbg !34 + %52 = xor i32 %50, %21, !dbg !34 + %53 = mul nuw nsw i32 %47, %38, !dbg !35 + %54 = mul nuw nsw i32 %48, %38, !dbg !35 + %55 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %53, i32 1, i32 31), !dbg !36 + %56 = add i32 %53, %55, !dbg !39 + %57 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %54, i32 1, i32 31), !dbg !36 + %58 = add i32 %54, %57, !dbg !39 + %59 = mul nuw nsw i32 %47, %35, !dbg !40 + %60 = mul nuw nsw i32 %48, %35, !dbg !40 + %61 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %59, i32 1, i32 31), !dbg !36 + %62 = add i32 %59, %61, !dbg !39 + %63 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %60, i32 1, i32 31), !dbg !36 + %64 = add i32 %60, %63, !dbg !39 + %65 = mul nuw nsw i32 %51, %38, !dbg !41 + %66 = mul nuw nsw i32 %52, %38, !dbg !41 + %67 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %65, i32 1, i32 31), !dbg !36 + %68 = add i32 %65, %67, !dbg !39 + %69 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %66, i32 1, i32 31), !dbg !36 + %70 = add i32 %66, %69, !dbg !39 + %71 = mul nuw nsw i32 %51, %35, !dbg !42 + %72 = mul nuw nsw i32 %52, %35, !dbg !42 + %73 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %71, i32 1, i32 31), !dbg !36 + %74 = add i32 %71, %73, !dbg !39 + %75 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %72, i32 1, i32 31), !dbg !36 + %76 = add i32 %72, %75, !dbg !39 + %77 = trunc i32 %36 to i1, !dbg !28 + %78 = insertelement <2 x i32> poison, i32 %56, i64 0, !dbg !28 + %79 = 
insertelement <2 x i32> %78, i32 %58, i64 1, !dbg !28 + %80 = insertelement <2 x i32> poison, i32 %62, i64 0, !dbg !28 + %81 = insertelement <2 x i32> %80, i32 %64, i64 1, !dbg !28 + %82 = icmp sge <2 x i32> %79, %81, !dbg !28 + %83 = icmp ne <2 x i32> %79, %81, !dbg !28 + %84 = insertelement <2 x i32> poison, i32 %68, i64 0, !dbg !28 + %85 = insertelement <2 x i32> %84, i32 %70, i64 1, !dbg !28 + %86 = insertelement <2 x i32> poison, i32 %74, i64 0, !dbg !28 + %87 = insertelement <2 x i32> %86, i32 %76, i64 1, !dbg !28 + %88 = icmp sle <2 x i32> %85, %87, !dbg !28 + %89 = or <2 x i1> %83, %88, !dbg !28 + %90 = and <2 x i1> %82, %89, !dbg !28 + %91 = insertelement <2 x i1> poison, i1 %77, i64 0, !dbg !28 + %92 = shufflevector <2 x i1> %91, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !28 + %93 = xor <2 x i1> %90, %92, !dbg !28 + %94 = xor <2 x i32> %79, %81, !dbg !29 + %95 = select <2 x i1> %93, <2 x i32> zeroinitializer, <2 x i32> %94, !dbg !30 + %96 = insertelement <2 x i32> poison, i32 %47, i64 0, !dbg !31 + %97 = insertelement <2 x i32> %96, i32 %48, i64 1, !dbg !31 + %98 = xor <2 x i32> %95, %97, !dbg !31 + %99 = xor <2 x i32> %85, %87, !dbg !32 + %100 = select <2 x i1> %93, <2 x i32> zeroinitializer, <2 x i32> %99, !dbg !33 + %101 = insertelement <2 x i32> poison, i32 %51, i64 0, !dbg !34 + %102 = insertelement <2 x i32> %101, i32 %52, i64 1, !dbg !34 + %103 = xor <2 x i32> %100, %102, !dbg !34 + %104 = extractelement <2 x i32> %98, i64 0, !dbg !28 + %105 = extractelement <2 x i32> %98, i64 1, !dbg !28 + %106 = icmp sge i32 %104, %105, !dbg !28 + %107 = icmp ne i32 %104, %105, !dbg !28 + %108 = extractelement <2 x i32> %103, i64 0, !dbg !28 + %109 = extractelement <2 x i32> %103, i64 1, !dbg !28 + %110 = icmp sle i32 %108, %109, !dbg !28 + %111 = or i1 %107, %110, !dbg !28 + %112 = and i1 %106, %111, !dbg !28 + %.not3 = xor i1 %112, %77, !dbg !28 + %113 = xor i32 %104, %105, !dbg !29 + %114 = select i1 %.not3, i32 0, i32 %113, !dbg !30 + %115 = xor i32 
%108, %109, !dbg !32 + %116 = select i1 %.not3, i32 0, i32 %115, !dbg !33 + %117 = xor i32 %116, %108, !dbg !34 + %118 = xor i32 %116, %109, !dbg !34 + %119 = mul nuw nsw i32 %117, %39, !dbg !41 + %120 = mul nuw nsw i32 %118, %39, !dbg !41 + %121 = mul nuw nsw i32 %117, %.lobit, !dbg !42 + %122 = mul nuw nsw i32 %118, %.lobit, !dbg !42 + %123 = trunc i32 %37 to i1, !dbg !28 + %124 = insertelement <2 x i32> poison, i32 %114, i64 0, !dbg !31 + %125 = shufflevector <2 x i32> %124, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !31 + %126 = xor <2 x i32> %125, %98, !dbg !31 + %127 = extractelement <2 x i32> %126, i64 0, !dbg !40 + %128 = mul nuw nsw i32 %127, %39, !dbg !35 + %129 = extractelement <2 x i32> %126, i64 1, !dbg !40 + %130 = mul nuw nsw i32 %129, %39, !dbg !35 + %131 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %128, i32 2, i32 31), !dbg !36 + %132 = add i32 %128, %131, !dbg !39 + %133 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %130, i32 2, i32 31), !dbg !36 + %134 = add i32 %130, %133, !dbg !39 + %135 = mul nuw nsw i32 %127, %.lobit, !dbg !40 + %136 = mul nuw nsw i32 %129, %.lobit, !dbg !40 + %137 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %135, i32 2, i32 31), !dbg !36 + %138 = add i32 %135, %137, !dbg !39 + %139 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %136, i32 2, i32 31), !dbg !36 + %140 = add i32 %136, %139, !dbg !39 + %141 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %119, i32 2, i32 31), !dbg !36 + %142 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %120, i32 2, i32 31), !dbg !36 + %143 = insertelement <2 x i32> poison, i32 %119, i64 0, !dbg !39 + %144 = insertelement <2 x i32> %143, i32 %120, i64 1, !dbg !39 + %145 = insertelement <2 x i32> poison, i32 %141, i64 0, !dbg !39 + %146 = insertelement <2 x i32> %145, i32 %142, i64 1, !dbg !39 + %147 = add <2 x i32> %144, %146, !dbg !39 + %148 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %121, i32 2, i32 
31), !dbg !36 + %149 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %122, i32 2, i32 31), !dbg !36 + %150 = insertelement <2 x i32> poison, i32 %121, i64 0, !dbg !39 + %151 = insertelement <2 x i32> %150, i32 %122, i64 1, !dbg !39 + %152 = insertelement <2 x i32> poison, i32 %148, i64 0, !dbg !39 + %153 = insertelement <2 x i32> %152, i32 %149, i64 1, !dbg !39 + %154 = add <2 x i32> %151, %153, !dbg !39 + %155 = insertelement <2 x i32> poison, i32 %132, i64 0, !dbg !28 + %156 = insertelement <2 x i32> %155, i32 %134, i64 1, !dbg !28 + %157 = insertelement <2 x i32> poison, i32 %138, i64 0, !dbg !28 + %158 = insertelement <2 x i32> %157, i32 %140, i64 1, !dbg !28 + %159 = icmp sge <2 x i32> %156, %158, !dbg !28 + %160 = icmp ne <2 x i32> %156, %158, !dbg !28 + %161 = xor <2 x i32> %156, %158, !dbg !29 + %162 = icmp sle <2 x i32> %147, %154, !dbg !28 + %163 = or <2 x i1> %160, %162, !dbg !28 + %164 = and <2 x i1> %159, %163, !dbg !28 + %165 = insertelement <2 x i1> poison, i1 %123, i64 0, !dbg !28 + %166 = shufflevector <2 x i1> %165, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !28 + %167 = xor <2 x i1> %164, %166, !dbg !28 + %168 = select <2 x i1> %167, <2 x i32> zeroinitializer, <2 x i32> %161, !dbg !30 + %169 = xor <2 x i32> %168, %126, !dbg !31 + %170 = xor <2 x i32> %147, %154, !dbg !32 + %171 = select <2 x i1> %167, <2 x i32> zeroinitializer, <2 x i32> %170, !dbg !33 + %172 = insertelement <2 x i32> poison, i32 %117, i64 0, !dbg !34 + %173 = insertelement <2 x i32> %172, i32 %118, i64 1, !dbg !34 + %174 = xor <2 x i32> %171, %173, !dbg !34 + %175 = extractelement <2 x i32> %169, i64 0, !dbg !40 + %176 = mul nuw nsw i32 %175, %38, !dbg !35 + %177 = extractelement <2 x i32> %169, i64 1, !dbg !40 + %178 = mul nuw nsw i32 %177, %38, !dbg !35 + %179 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %176, i32 1, i32 31), !dbg !36 + %180 = add i32 %176, %179, !dbg !39 + %181 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %178, i32 1, 
i32 31), !dbg !36 + %182 = add i32 %178, %181, !dbg !39 + %183 = mul nuw nsw i32 %175, %35, !dbg !40 + %184 = mul nuw nsw i32 %177, %35, !dbg !40 + %185 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %183, i32 1, i32 31), !dbg !36 + %186 = add i32 %183, %185, !dbg !39 + %187 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %184, i32 1, i32 31), !dbg !36 + %188 = add i32 %184, %187, !dbg !39 + %189 = extractelement <2 x i32> %174, i64 0, !dbg !42 + %190 = mul nuw nsw i32 %189, %38, !dbg !41 + %191 = extractelement <2 x i32> %174, i64 1, !dbg !42 + %192 = mul nuw nsw i32 %191, %38, !dbg !41 + %193 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %190, i32 1, i32 31), !dbg !36 + %194 = add i32 %190, %193, !dbg !39 + %195 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %192, i32 1, i32 31), !dbg !36 + %196 = add i32 %192, %195, !dbg !39 + %197 = mul nuw nsw i32 %189, %35, !dbg !42 + %198 = mul nuw nsw i32 %191, %35, !dbg !42 + %199 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %197, i32 1, i32 31), !dbg !36 + %200 = add i32 %197, %199, !dbg !39 + %201 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %198, i32 1, i32 31), !dbg !36 + %202 = add i32 %198, %201, !dbg !39 + %203 = insertelement <2 x i32> poison, i32 %180, i64 0, !dbg !28 + %204 = insertelement <2 x i32> %203, i32 %182, i64 1, !dbg !28 + %205 = insertelement <2 x i32> poison, i32 %186, i64 0, !dbg !28 + %206 = insertelement <2 x i32> %205, i32 %188, i64 1, !dbg !28 + %207 = icmp sge <2 x i32> %204, %206, !dbg !28 + %208 = icmp ne <2 x i32> %204, %206, !dbg !28 + %209 = insertelement <2 x i32> poison, i32 %194, i64 0, !dbg !28 + %210 = insertelement <2 x i32> %209, i32 %196, i64 1, !dbg !28 + %211 = insertelement <2 x i32> poison, i32 %200, i64 0, !dbg !28 + %212 = insertelement <2 x i32> %211, i32 %202, i64 1, !dbg !28 + %213 = icmp sle <2 x i32> %210, %212, !dbg !28 + %214 = or <2 x i1> %208, %213, !dbg !28 + %215 = and <2 x i1> %207, %214, !dbg 
!28 + %216 = xor <2 x i1> %215, %166, !dbg !28 + %217 = xor <2 x i32> %204, %206, !dbg !29 + %218 = select <2 x i1> %216, <2 x i32> zeroinitializer, <2 x i32> %217, !dbg !30 + %219 = xor <2 x i32> %218, %169, !dbg !31 + %220 = xor <2 x i32> %210, %212, !dbg !32 + %221 = select <2 x i1> %216, <2 x i32> zeroinitializer, <2 x i32> %220, !dbg !33 + %222 = xor <2 x i32> %221, %174, !dbg !34 + %223 = extractelement <2 x i32> %219, i64 0, !dbg !28 + %224 = extractelement <2 x i32> %219, i64 1, !dbg !28 + %225 = icmp sge i32 %223, %224, !dbg !28 + %226 = icmp ne i32 %223, %224, !dbg !28 + %227 = extractelement <2 x i32> %222, i64 0, !dbg !28 + %228 = extractelement <2 x i32> %222, i64 1, !dbg !28 + %229 = icmp sle i32 %227, %228, !dbg !28 + %230 = or i1 %226, %229, !dbg !28 + %231 = and i1 %225, %230, !dbg !28 + %.not8 = xor i1 %231, %123, !dbg !28 + %232 = xor i32 %223, %224, !dbg !29 + %233 = select i1 %.not8, i32 0, i32 %232, !dbg !30 + %234 = insertelement <2 x i32> poison, i32 %233, i64 0, !dbg !31 + %235 = shufflevector <2 x i32> %234, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !31 + %236 = xor <2 x i32> %235, %219, !dbg !31 + %237 = xor i32 %227, %228, !dbg !32 + %238 = select i1 %.not8, i32 0, i32 %237, !dbg !33 + %239 = insertelement <2 x i32> poison, i32 %238, i64 0, !dbg !34 + %240 = shufflevector <2 x i32> %239, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !34 + %241 = xor <2 x i32> %240, %222, !dbg !34 + %242 = extractelement <2 x i32> %236, i64 0, !dbg !40 + %243 = mul nuw nsw i32 %242, %40, !dbg !35 + %244 = extractelement <2 x i32> %236, i64 1, !dbg !40 + %245 = mul nuw nsw i32 %244, %40, !dbg !35 + %246 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %243, i32 4, i32 31), !dbg !36 + %247 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %245, i32 4, i32 31), !dbg !36 + %248 = insertelement <2 x i32> poison, i32 %243, i64 0, !dbg !39 + %249 = insertelement <2 x i32> %248, i32 %245, i64 1, !dbg !39 + %250 = insertelement <2 x 
i32> poison, i32 %246, i64 0, !dbg !39 + %251 = insertelement <2 x i32> %250, i32 %247, i64 1, !dbg !39 + %252 = add <2 x i32> %249, %251, !dbg !39 + %253 = mul nuw nsw i32 %242, %.lobit1, !dbg !40 + %254 = mul nuw nsw i32 %244, %.lobit1, !dbg !40 + %255 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %253, i32 4, i32 31), !dbg !36 + %256 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %254, i32 4, i32 31), !dbg !36 + %257 = insertelement <2 x i32> poison, i32 %253, i64 0, !dbg !39 + %258 = insertelement <2 x i32> %257, i32 %254, i64 1, !dbg !39 + %259 = insertelement <2 x i32> poison, i32 %255, i64 0, !dbg !39 + %260 = insertelement <2 x i32> %259, i32 %256, i64 1, !dbg !39 + %261 = add <2 x i32> %258, %260, !dbg !39 + %262 = extractelement <2 x i32> %241, i64 0, !dbg !42 + %263 = mul nuw nsw i32 %262, %40, !dbg !41 + %264 = extractelement <2 x i32> %241, i64 1, !dbg !42 + %265 = mul nuw nsw i32 %264, %40, !dbg !41 + %266 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %263, i32 4, i32 31), !dbg !36 + %267 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %265, i32 4, i32 31), !dbg !36 + %268 = insertelement <2 x i32> poison, i32 %263, i64 0, !dbg !39 + %269 = insertelement <2 x i32> %268, i32 %265, i64 1, !dbg !39 + %270 = insertelement <2 x i32> poison, i32 %266, i64 0, !dbg !39 + %271 = insertelement <2 x i32> %270, i32 %267, i64 1, !dbg !39 + %272 = add <2 x i32> %269, %271, !dbg !39 + %273 = mul nuw nsw i32 %262, %.lobit1, !dbg !42 + %274 = mul nuw nsw i32 %264, %.lobit1, !dbg !42 + %275 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %273, i32 4, i32 31), !dbg !36 + %276 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %274, i32 4, i32 31), !dbg !36 + %277 = insertelement <2 x i32> poison, i32 %273, i64 0, !dbg !39 + %278 = insertelement <2 x i32> %277, i32 %274, i64 1, !dbg !39 + %279 = insertelement <2 x i32> poison, i32 %275, i64 0, !dbg !39 + %280 = insertelement <2 x i32> %279, i32 %276, i64 1, 
!dbg !39 + %281 = add <2 x i32> %278, %280, !dbg !39 + %282 = icmp slt <2 x i32> %252, %261, !dbg !27 + %283 = icmp eq <2 x i32> %252, %261, !dbg !43 + %284 = icmp sgt <2 x i32> %272, %281, !dbg !44 + %285 = and <2 x i1> %283, %284, !dbg !45 + %286 = or <2 x i1> %282, %285, !dbg !46 + %287 = xor <2 x i32> %252, %261, !dbg !29 + %288 = select <2 x i1> %286, <2 x i32> %287, <2 x i32> zeroinitializer, !dbg !30 + %289 = xor <2 x i32> %288, %236, !dbg !31 + %290 = xor <2 x i32> %272, %281, !dbg !32 + %291 = select <2 x i1> %286, <2 x i32> %290, <2 x i32> zeroinitializer, !dbg !33 + %292 = xor <2 x i32> %291, %241, !dbg !34 + %293 = insertelement <2 x i32> poison, i32 %39, i64 0, !dbg !35 + %294 = shufflevector <2 x i32> %293, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !35 + %295 = mul nuw nsw <2 x i32> %289, %294, !dbg !35 + %296 = extractelement <2 x i32> %295, i64 0, !dbg !36 + %297 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %296, i32 2, i32 31), !dbg !36 + %298 = extractelement <2 x i32> %295, i64 1, !dbg !36 + %299 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %298, i32 2, i32 31), !dbg !36 + %300 = insertelement <2 x i32> poison, i32 %297, i64 0, !dbg !39 + %301 = insertelement <2 x i32> %300, i32 %299, i64 1, !dbg !39 + %302 = add <2 x i32> %295, %301, !dbg !39 + %303 = insertelement <2 x i32> poison, i32 %.lobit, i64 0, !dbg !40 + %304 = shufflevector <2 x i32> %303, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !40 + %305 = mul nuw nsw <2 x i32> %289, %304, !dbg !40 + %306 = extractelement <2 x i32> %305, i64 0, !dbg !36 + %307 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %306, i32 2, i32 31), !dbg !36 + %308 = extractelement <2 x i32> %305, i64 1, !dbg !36 + %309 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %308, i32 2, i32 31), !dbg !36 + %310 = insertelement <2 x i32> poison, i32 %307, i64 0, !dbg !39 + %311 = insertelement <2 x i32> %310, i32 %309, i64 1, !dbg !39 + %312 = add <2 x i32> %305, 
%311, !dbg !39 + %313 = extractelement <2 x i32> %292, i64 0, !dbg !42 + %314 = mul nuw nsw i32 %313, %39, !dbg !41 + %315 = extractelement <2 x i32> %292, i64 1, !dbg !42 + %316 = mul nuw nsw i32 %315, %39, !dbg !41 + %317 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %314, i32 2, i32 31), !dbg !36 + %318 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %316, i32 2, i32 31), !dbg !36 + %319 = insertelement <2 x i32> poison, i32 %314, i64 0, !dbg !39 + %320 = insertelement <2 x i32> %319, i32 %316, i64 1, !dbg !39 + %321 = insertelement <2 x i32> poison, i32 %317, i64 0, !dbg !39 + %322 = insertelement <2 x i32> %321, i32 %318, i64 1, !dbg !39 + %323 = add <2 x i32> %320, %322, !dbg !39 + %324 = mul nuw nsw i32 %313, %.lobit, !dbg !42 + %325 = mul nuw nsw i32 %315, %.lobit, !dbg !42 + %326 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %324, i32 2, i32 31), !dbg !36 + %327 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %325, i32 2, i32 31), !dbg !36 + %328 = insertelement <2 x i32> poison, i32 %324, i64 0, !dbg !39 + %329 = insertelement <2 x i32> %328, i32 %325, i64 1, !dbg !39 + %330 = insertelement <2 x i32> poison, i32 %326, i64 0, !dbg !39 + %331 = insertelement <2 x i32> %330, i32 %327, i64 1, !dbg !39 + %332 = add <2 x i32> %329, %331, !dbg !39 + %333 = icmp slt <2 x i32> %302, %312, !dbg !27 + %334 = icmp eq <2 x i32> %302, %312, !dbg !43 + %335 = icmp sgt <2 x i32> %323, %332, !dbg !44 + %336 = and <2 x i1> %334, %335, !dbg !45 + %337 = or <2 x i1> %333, %336, !dbg !46 + %338 = xor <2 x i32> %302, %312, !dbg !29 + %339 = select <2 x i1> %337, <2 x i32> %338, <2 x i32> zeroinitializer, !dbg !30 + %340 = xor <2 x i32> %339, %289, !dbg !31 + %341 = xor <2 x i32> %323, %332, !dbg !32 + %342 = select <2 x i1> %337, <2 x i32> %341, <2 x i32> zeroinitializer, !dbg !33 + %343 = xor <2 x i32> %342, %292, !dbg !34 + %344 = insertelement <2 x i32> poison, i32 %38, i64 0, !dbg !35 + %345 = shufflevector <2 x i32> %344, <2 x 
i32> poison, <2 x i32> zeroinitializer, !dbg !35 + %346 = mul nuw nsw <2 x i32> %340, %345, !dbg !35 + %347 = extractelement <2 x i32> %346, i64 0, !dbg !36 + %348 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %347, i32 1, i32 31), !dbg !36 + %349 = extractelement <2 x i32> %346, i64 1, !dbg !36 + %350 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %349, i32 1, i32 31), !dbg !36 + %351 = insertelement <2 x i32> poison, i32 %348, i64 0, !dbg !39 + %352 = insertelement <2 x i32> %351, i32 %350, i64 1, !dbg !39 + %353 = add <2 x i32> %346, %352, !dbg !39 + %354 = insertelement <2 x i32> poison, i32 %35, i64 0, !dbg !40 + %355 = shufflevector <2 x i32> %354, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !40 + %356 = mul nuw nsw <2 x i32> %340, %355, !dbg !40 + %357 = extractelement <2 x i32> %356, i64 0, !dbg !36 + %358 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %357, i32 1, i32 31), !dbg !36 + %359 = extractelement <2 x i32> %356, i64 1, !dbg !36 + %360 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %359, i32 1, i32 31), !dbg !36 + %361 = insertelement <2 x i32> poison, i32 %358, i64 0, !dbg !39 + %362 = insertelement <2 x i32> %361, i32 %360, i64 1, !dbg !39 + %363 = add <2 x i32> %356, %362, !dbg !39 + %364 = mul nuw nsw <2 x i32> %343, %345, !dbg !41 + %365 = extractelement <2 x i32> %364, i64 0, !dbg !36 + %366 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %365, i32 1, i32 31), !dbg !36 + %367 = extractelement <2 x i32> %364, i64 1, !dbg !36 + %368 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %367, i32 1, i32 31), !dbg !36 + %369 = insertelement <2 x i32> poison, i32 %366, i64 0, !dbg !39 + %370 = insertelement <2 x i32> %369, i32 %368, i64 1, !dbg !39 + %371 = add <2 x i32> %364, %370, !dbg !39 + %372 = mul nuw nsw <2 x i32> %343, %355, !dbg !42 + %373 = extractelement <2 x i32> %372, i64 0, !dbg !36 + %374 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %373, i32 1, i32 31), 
!dbg !36 + %375 = extractelement <2 x i32> %372, i64 1, !dbg !36 + %376 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %375, i32 1, i32 31), !dbg !36 + %377 = insertelement <2 x i32> poison, i32 %374, i64 0, !dbg !39 + %378 = insertelement <2 x i32> %377, i32 %376, i64 1, !dbg !39 + %379 = add <2 x i32> %372, %378, !dbg !39 + %380 = icmp slt <2 x i32> %353, %363, !dbg !27 + %381 = icmp eq <2 x i32> %353, %363, !dbg !43 + %382 = icmp sgt <2 x i32> %371, %379, !dbg !44 + %383 = and <2 x i1> %381, %382, !dbg !45 + %384 = or <2 x i1> %380, %383, !dbg !46 + %385 = xor <2 x i32> %353, %363, !dbg !29 + %386 = select <2 x i1> %384, <2 x i32> %385, <2 x i32> zeroinitializer, !dbg !30 + %387 = xor <2 x i32> %386, %340, !dbg !31 + %388 = xor <2 x i32> %371, %379, !dbg !32 + %389 = select <2 x i1> %384, <2 x i32> %388, <2 x i32> zeroinitializer, !dbg !33 + %390 = xor <2 x i32> %389, %343, !dbg !34 + %391 = extractelement <2 x i32> %387, i64 0, !dbg !27 + %392 = extractelement <2 x i32> %387, i64 1, !dbg !27 + %393 = icmp slt i32 %391, %392, !dbg !27 + %394 = icmp eq i32 %391, %392, !dbg !43 + %395 = extractelement <2 x i32> %390, i64 0, !dbg !44 + %396 = extractelement <2 x i32> %390, i64 1, !dbg !44 + %397 = icmp sgt i32 %395, %396, !dbg !44 + %398 = and i1 %394, %397, !dbg !45 + %399 = or i1 %393, %398, !dbg !46 + %400 = xor i32 %395, %396, !dbg !32 + %401 = select i1 %399, i32 %400, i32 0, !dbg !33 + %402 = xor i32 %401, %395, !dbg !34 + %403 = xor i32 %401, %396, !dbg !34 + %404 = icmp eq i64 %27, 16384, !dbg !47 + %405 = icmp eq i64 %28, 16384, !dbg !47 + %406 = zext i1 %404 to i32, !dbg !21 + %407 = zext i1 %405 to i32, !dbg !21 + %408 = xor i1 %404, true, !dbg !48 + %409 = and i1 %405, %408, !dbg !48 + %410 = xor i1 %409, %43, !dbg !50 + %411 = xor i32 %406, %407, !dbg !51 + %412 = select i1 %410, i32 %411, i32 0, !dbg !52 + %413 = xor i32 %412, %406, !dbg !53 + %414 = xor i32 %412, %407, !dbg !53 + %415 = select i1 %410, i32 %49, i32 0, !dbg !54 + %416 = xor 
i32 %415, %20, !dbg !55 + %417 = xor i32 %415, %21, !dbg !55 + %418 = mul nuw nsw i32 %413, %38, !dbg !56 + %419 = mul nuw nsw i32 %414, %38, !dbg !56 + %420 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %418, i32 1, i32 31), !dbg !57 + %421 = add i32 %420, %418, !dbg !58 + %422 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %419, i32 1, i32 31), !dbg !57 + %423 = add i32 %422, %419, !dbg !58 + %424 = mul nuw nsw i32 %413, %35, !dbg !59 + %425 = mul nuw nsw i32 %414, %35, !dbg !59 + %426 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %424, i32 1, i32 31), !dbg !57 + %427 = add i32 %426, %424, !dbg !58 + %428 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %425, i32 1, i32 31), !dbg !57 + %429 = add i32 %428, %425, !dbg !58 + %430 = mul nuw nsw i32 %416, %38, !dbg !60 + %431 = mul nuw nsw i32 %417, %38, !dbg !60 + %432 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %430, i32 1, i32 31), !dbg !57 + %433 = add i32 %432, %430, !dbg !58 + %434 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %431, i32 1, i32 31), !dbg !57 + %435 = add i32 %434, %431, !dbg !58 + %436 = mul nuw nsw i32 %416, %35, !dbg !61 + %437 = mul nuw nsw i32 %417, %35, !dbg !61 + %438 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %436, i32 1, i32 31), !dbg !57 + %439 = add i32 %438, %436, !dbg !58 + %440 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %437, i32 1, i32 31), !dbg !57 + %441 = add i32 %440, %437, !dbg !58 + %442 = insertelement <2 x i32> poison, i32 %421, i64 0, !dbg !50 + %443 = insertelement <2 x i32> %442, i32 %423, i64 1, !dbg !50 + %444 = insertelement <2 x i32> poison, i32 %427, i64 0, !dbg !50 + %445 = insertelement <2 x i32> %444, i32 %429, i64 1, !dbg !50 + %446 = icmp sge <2 x i32> %443, %445, !dbg !50 + %447 = icmp ne <2 x i32> %443, %445, !dbg !50 + %448 = insertelement <2 x i32> poison, i32 %433, i64 0, !dbg !50 + %449 = insertelement <2 x i32> %448, i32 %435, i64 1, !dbg !50 + %450 = 
insertelement <2 x i32> poison, i32 %439, i64 0, !dbg !50 + %451 = insertelement <2 x i32> %450, i32 %441, i64 1, !dbg !50 + %452 = icmp sle <2 x i32> %449, %451, !dbg !50 + %453 = or <2 x i1> %447, %452, !dbg !50 + %454 = and <2 x i1> %446, %453, !dbg !50 + %455 = xor <2 x i1> %454, %92, !dbg !50 + %456 = insertelement <2 x i32> %451, i32 %429, i64 1, !dbg !51 + %457 = insertelement <2 x i32> %449, i32 %423, i64 1, !dbg !51 + %458 = xor <2 x i32> %456, %457, !dbg !51 + %459 = insertelement <2 x i32> poison, i32 %441, i64 0, !dbg !51 + %460 = insertelement <2 x i32> %459, i32 %427, i64 1, !dbg !51 + %461 = insertelement <2 x i32> poison, i32 %435, i64 0, !dbg !51 + %462 = insertelement <2 x i32> %461, i32 %421, i64 1, !dbg !51 + %463 = xor <2 x i32> %460, %462, !dbg !51 + %464 = select <2 x i1> %455, <2 x i32> zeroinitializer, <2 x i32> %458, !dbg !52 + %465 = shufflevector <2 x i1> %455, <2 x i1> poison, <2 x i32> , !dbg !52 + %466 = select <2 x i1> %465, <2 x i32> zeroinitializer, <2 x i32> %463, !dbg !52 + %467 = insertelement <2 x i32> poison, i32 %416, i64 0, !dbg !53 + %468 = insertelement <2 x i32> %467, i32 %414, i64 1, !dbg !53 + %469 = xor <2 x i32> %464, %468, !dbg !53 + %470 = insertelement <2 x i32> poison, i32 %417, i64 0, !dbg !53 + %471 = insertelement <2 x i32> %470, i32 %413, i64 1, !dbg !53 + %472 = xor <2 x i32> %466, %471, !dbg !53 + %473 = extractelement <2 x i32> %469, i64 1, !dbg !50 + %474 = extractelement <2 x i32> %472, i64 1, !dbg !50 + %475 = icmp sge i32 %474, %473, !dbg !50 + %476 = icmp ne <2 x i32> %469, %472, !dbg !50 + %477 = icmp sle <2 x i32> %469, %472, !dbg !50 + %shift = shufflevector <2 x i1> %476, <2 x i1> poison, <2 x i32> , !dbg !50 + %foldExtExtBinop = or <2 x i1> %shift, %477, !dbg !50 + %478 = extractelement <2 x i1> %foldExtExtBinop, i64 0, !dbg !50 + %479 = and i1 %475, %478, !dbg !50 + %.not12 = xor i1 %479, %77, !dbg !50 + %480 = xor i32 %473, %474, !dbg !51 + %481 = select i1 %.not12, i32 0, i32 %480, !dbg !52 + 
%482 = xor i32 %481, %474, !dbg !53 + %483 = xor i32 %481, %473, !dbg !53 + %484 = extractelement <2 x i32> %469, i64 0, !dbg !62 + %485 = extractelement <2 x i32> %472, i64 0, !dbg !62 + %486 = xor i32 %485, %484, !dbg !62 + %487 = select i1 %.not12, i32 0, i32 %486, !dbg !54 + %488 = xor i32 %487, %484, !dbg !55 + %489 = xor i32 %487, %485, !dbg !55 + %490 = mul nuw nsw i32 %482, %39, !dbg !56 + %491 = mul nuw nsw i32 %483, %39, !dbg !56 + %492 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %490, i32 2, i32 31), !dbg !57 + %493 = add i32 %490, %492, !dbg !58 + %494 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %491, i32 2, i32 31), !dbg !57 + %495 = add i32 %491, %494, !dbg !58 + %496 = mul nuw nsw i32 %482, %.lobit, !dbg !59 + %497 = mul nuw nsw i32 %483, %.lobit, !dbg !59 + %498 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %496, i32 2, i32 31), !dbg !57 + %499 = add i32 %496, %498, !dbg !58 + %500 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %497, i32 2, i32 31), !dbg !57 + %501 = add i32 %497, %500, !dbg !58 + %502 = mul nuw nsw i32 %488, %39, !dbg !60 + %503 = mul nuw nsw i32 %489, %39, !dbg !60 + %504 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %502, i32 2, i32 31), !dbg !57 + %505 = add i32 %502, %504, !dbg !58 + %506 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %503, i32 2, i32 31), !dbg !57 + %507 = add i32 %503, %506, !dbg !58 + %508 = mul nuw nsw i32 %488, %.lobit, !dbg !61 + %509 = mul nuw nsw i32 %489, %.lobit, !dbg !61 + %510 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %508, i32 2, i32 31), !dbg !57 + %511 = add i32 %508, %510, !dbg !58 + %512 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %509, i32 2, i32 31), !dbg !57 + %513 = add i32 %509, %512, !dbg !58 + %514 = insertelement <2 x i32> poison, i32 %493, i64 0, !dbg !50 + %515 = insertelement <2 x i32> %514, i32 %495, i64 1, !dbg !50 + %516 = insertelement <2 x i32> poison, i32 %499, i64 0, !dbg 
!50 + %517 = insertelement <2 x i32> %516, i32 %501, i64 1, !dbg !50 + %518 = icmp sge <2 x i32> %515, %517, !dbg !50 + %519 = icmp ne <2 x i32> %515, %517, !dbg !50 + %520 = xor i32 %493, %499, !dbg !51 + %521 = xor i32 %495, %501, !dbg !51 + %522 = insertelement <2 x i32> poison, i32 %505, i64 0, !dbg !50 + %523 = insertelement <2 x i32> %522, i32 %507, i64 1, !dbg !50 + %524 = insertelement <2 x i32> poison, i32 %511, i64 0, !dbg !50 + %525 = insertelement <2 x i32> %524, i32 %513, i64 1, !dbg !50 + %526 = icmp sle <2 x i32> %523, %525, !dbg !50 + %527 = or <2 x i1> %519, %526, !dbg !50 + %528 = and <2 x i1> %518, %527, !dbg !50 + %529 = xor <2 x i1> %528, %166, !dbg !50 + %530 = extractelement <2 x i1> %529, i64 0, !dbg !52 + %531 = select i1 %530, i32 0, i32 %520, !dbg !52 + %532 = extractelement <2 x i1> %529, i64 1, !dbg !52 + %533 = select i1 %532, i32 0, i32 %521, !dbg !52 + %534 = xor i32 %531, %482, !dbg !53 + %535 = xor i32 %533, %483, !dbg !53 + %536 = xor <2 x i32> %523, %525, !dbg !62 + %537 = select <2 x i1> %529, <2 x i32> zeroinitializer, <2 x i32> %536, !dbg !54 + %538 = insertelement <2 x i32> poison, i32 %488, i64 0, !dbg !55 + %539 = insertelement <2 x i32> %538, i32 %489, i64 1, !dbg !55 + %540 = xor <2 x i32> %537, %539, !dbg !55 + %541 = mul nuw nsw i32 %534, %38, !dbg !56 + %542 = mul nuw nsw i32 %535, %38, !dbg !56 + %543 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %541, i32 1, i32 31), !dbg !57 + %544 = add i32 %541, %543, !dbg !58 + %545 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %542, i32 1, i32 31), !dbg !57 + %546 = add i32 %542, %545, !dbg !58 + %547 = mul nuw nsw i32 %534, %35, !dbg !59 + %548 = mul nuw nsw i32 %535, %35, !dbg !59 + %549 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %547, i32 1, i32 31), !dbg !57 + %550 = add i32 %547, %549, !dbg !58 + %551 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %548, i32 1, i32 31), !dbg !57 + %552 = add i32 %548, %551, !dbg !58 + %553 = 
extractelement <2 x i32> %540, i64 0, !dbg !61 + %554 = mul nuw nsw i32 %553, %38, !dbg !60 + %555 = extractelement <2 x i32> %540, i64 1, !dbg !61 + %556 = mul nuw nsw i32 %555, %38, !dbg !60 + %557 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %554, i32 1, i32 31), !dbg !57 + %558 = add i32 %554, %557, !dbg !58 + %559 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %556, i32 1, i32 31), !dbg !57 + %560 = add i32 %556, %559, !dbg !58 + %561 = mul nuw nsw i32 %553, %35, !dbg !61 + %562 = mul nuw nsw i32 %555, %35, !dbg !61 + %563 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %561, i32 1, i32 31), !dbg !57 + %564 = add i32 %561, %563, !dbg !58 + %565 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %562, i32 1, i32 31), !dbg !57 + %566 = add i32 %562, %565, !dbg !58 + %567 = insertelement <2 x i32> poison, i32 %544, i64 0, !dbg !50 + %568 = insertelement <2 x i32> %567, i32 %546, i64 1, !dbg !50 + %569 = insertelement <2 x i32> poison, i32 %550, i64 0, !dbg !50 + %570 = insertelement <2 x i32> %569, i32 %552, i64 1, !dbg !50 + %571 = icmp sge <2 x i32> %568, %570, !dbg !50 + %572 = icmp ne <2 x i32> %568, %570, !dbg !50 + %573 = insertelement <2 x i32> poison, i32 %558, i64 0, !dbg !50 + %574 = insertelement <2 x i32> %573, i32 %560, i64 1, !dbg !50 + %575 = insertelement <2 x i32> poison, i32 %564, i64 0, !dbg !50 + %576 = insertelement <2 x i32> %575, i32 %566, i64 1, !dbg !50 + %577 = icmp sle <2 x i32> %574, %576, !dbg !50 + %578 = or <2 x i1> %572, %577, !dbg !50 + %579 = and <2 x i1> %571, %578, !dbg !50 + %580 = xor <2 x i1> %579, %166, !dbg !50 + %581 = xor <2 x i32> %568, %570, !dbg !51 + %582 = select <2 x i1> %580, <2 x i32> zeroinitializer, <2 x i32> %581, !dbg !52 + %583 = insertelement <2 x i32> poison, i32 %534, i64 0, !dbg !53 + %584 = insertelement <2 x i32> %583, i32 %535, i64 1, !dbg !53 + %585 = xor <2 x i32> %582, %584, !dbg !53 + %586 = xor <2 x i32> %574, %576, !dbg !62 + %587 = select <2 x i1> %580, 
<2 x i32> zeroinitializer, <2 x i32> %586, !dbg !54 + %588 = xor <2 x i32> %587, %540, !dbg !55 + %589 = extractelement <2 x i32> %585, i64 0, !dbg !50 + %590 = extractelement <2 x i32> %585, i64 1, !dbg !50 + %591 = icmp sge i32 %589, %590, !dbg !50 + %592 = icmp ne i32 %589, %590, !dbg !50 + %593 = extractelement <2 x i32> %588, i64 0, !dbg !50 + %594 = extractelement <2 x i32> %588, i64 1, !dbg !50 + %595 = icmp sle i32 %593, %594, !dbg !50 + %596 = or i1 %592, %595, !dbg !50 + %597 = and i1 %591, %596, !dbg !50 + %.not17 = xor i1 %597, %123, !dbg !50 + %598 = xor i32 %589, %590, !dbg !51 + %599 = select i1 %.not17, i32 0, i32 %598, !dbg !52 + %600 = xor i32 %599, %589, !dbg !53 + %601 = xor i32 %599, %590, !dbg !53 + %602 = xor i32 %593, %594, !dbg !62 + %603 = select i1 %.not17, i32 0, i32 %602, !dbg !54 + %604 = xor i32 %603, %593, !dbg !55 + %605 = xor i32 %603, %594, !dbg !55 + %606 = mul nuw nsw i32 %600, %40, !dbg !56 + %607 = mul nuw nsw i32 %601, %40, !dbg !56 + %608 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %606, i32 4, i32 31), !dbg !57 + %609 = add i32 %606, %608, !dbg !58 + %610 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %607, i32 4, i32 31), !dbg !57 + %611 = add i32 %607, %610, !dbg !58 + %612 = mul nuw nsw i32 %600, %.lobit1, !dbg !59 + %613 = mul nuw nsw i32 %601, %.lobit1, !dbg !59 + %614 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %612, i32 4, i32 31), !dbg !57 + %615 = add i32 %612, %614, !dbg !58 + %616 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %613, i32 4, i32 31), !dbg !57 + %617 = add i32 %613, %616, !dbg !58 + %618 = mul nuw nsw i32 %604, %40, !dbg !60 + %619 = mul nuw nsw i32 %605, %40, !dbg !60 + %620 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %618, i32 4, i32 31), !dbg !57 + %621 = add i32 %618, %620, !dbg !58 + %622 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %619, i32 4, i32 31), !dbg !57 + %623 = add i32 %619, %622, !dbg !58 + %624 = mul nuw 
nsw i32 %604, %.lobit1, !dbg !61 + %625 = mul nuw nsw i32 %605, %.lobit1, !dbg !61 + %626 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %624, i32 4, i32 31), !dbg !57 + %627 = add i32 %624, %626, !dbg !58 + %628 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %625, i32 4, i32 31), !dbg !57 + %629 = add i32 %625, %628, !dbg !58 + %630 = insertelement <2 x i32> poison, i32 %609, i64 0, !dbg !48 + %631 = insertelement <2 x i32> %630, i32 %611, i64 1, !dbg !48 + %632 = insertelement <2 x i32> poison, i32 %615, i64 0, !dbg !48 + %633 = insertelement <2 x i32> %632, i32 %617, i64 1, !dbg !48 + %634 = icmp slt <2 x i32> %631, %633, !dbg !48 + %635 = icmp eq <2 x i32> %631, %633, !dbg !63 + %636 = insertelement <2 x i32> poison, i32 %621, i64 0, !dbg !64 + %637 = insertelement <2 x i32> %636, i32 %623, i64 1, !dbg !64 + %638 = insertelement <2 x i32> poison, i32 %627, i64 0, !dbg !64 + %639 = insertelement <2 x i32> %638, i32 %629, i64 1, !dbg !64 + %640 = icmp sgt <2 x i32> %637, %639, !dbg !64 + %641 = and <2 x i1> %635, %640, !dbg !65 + %642 = or <2 x i1> %634, %641, !dbg !66 + %643 = xor <2 x i32> %631, %633, !dbg !51 + %644 = select <2 x i1> %642, <2 x i32> %643, <2 x i32> zeroinitializer, !dbg !52 + %645 = insertelement <2 x i32> poison, i32 %600, i64 0, !dbg !53 + %646 = insertelement <2 x i32> %645, i32 %601, i64 1, !dbg !53 + %647 = xor <2 x i32> %644, %646, !dbg !53 + %648 = xor i32 %621, %627, !dbg !62 + %649 = xor i32 %623, %629, !dbg !62 + %650 = extractelement <2 x i1> %642, i64 0, !dbg !54 + %651 = select i1 %650, i32 %648, i32 0, !dbg !54 + %652 = extractelement <2 x i1> %642, i64 1, !dbg !54 + %653 = select i1 %652, i32 %649, i32 0, !dbg !54 + %654 = xor i32 %651, %604, !dbg !55 + %655 = xor i32 %653, %605, !dbg !55 + %656 = extractelement <2 x i32> %647, i64 0, !dbg !59 + %657 = mul nuw nsw i32 %656, %39, !dbg !56 + %658 = extractelement <2 x i32> %647, i64 1, !dbg !59 + %659 = mul nuw nsw i32 %658, %39, !dbg !56 + %660 = tail call 
i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %657, i32 2, i32 31), !dbg !57 + %661 = add i32 %657, %660, !dbg !58 + %662 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %659, i32 2, i32 31), !dbg !57 + %663 = add i32 %659, %662, !dbg !58 + %664 = mul nuw nsw i32 %656, %.lobit, !dbg !59 + %665 = mul nuw nsw i32 %658, %.lobit, !dbg !59 + %666 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %664, i32 2, i32 31), !dbg !57 + %667 = add i32 %664, %666, !dbg !58 + %668 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %665, i32 2, i32 31), !dbg !57 + %669 = add i32 %665, %668, !dbg !58 + %670 = mul nuw nsw i32 %654, %39, !dbg !60 + %671 = mul nuw nsw i32 %655, %39, !dbg !60 + %672 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %670, i32 2, i32 31), !dbg !57 + %673 = add i32 %670, %672, !dbg !58 + %674 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %671, i32 2, i32 31), !dbg !57 + %675 = add i32 %671, %674, !dbg !58 + %676 = mul nuw nsw i32 %654, %.lobit, !dbg !61 + %677 = mul nuw nsw i32 %655, %.lobit, !dbg !61 + %678 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %676, i32 2, i32 31), !dbg !57 + %679 = add i32 %676, %678, !dbg !58 + %680 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %677, i32 2, i32 31), !dbg !57 + %681 = add i32 %677, %680, !dbg !58 + %682 = insertelement <2 x i32> poison, i32 %661, i64 0, !dbg !48 + %683 = insertelement <2 x i32> %682, i32 %663, i64 1, !dbg !48 + %684 = insertelement <2 x i32> poison, i32 %667, i64 0, !dbg !48 + %685 = insertelement <2 x i32> %684, i32 %669, i64 1, !dbg !48 + %686 = icmp slt <2 x i32> %683, %685, !dbg !48 + %687 = icmp eq <2 x i32> %683, %685, !dbg !63 + %688 = insertelement <2 x i32> poison, i32 %673, i64 0, !dbg !64 + %689 = insertelement <2 x i32> %688, i32 %675, i64 1, !dbg !64 + %690 = insertelement <2 x i32> poison, i32 %679, i64 0, !dbg !64 + %691 = insertelement <2 x i32> %690, i32 %681, i64 1, !dbg !64 + %692 = icmp sgt <2 x i32> %689, 
%691, !dbg !64 + %693 = and <2 x i1> %687, %692, !dbg !65 + %694 = or <2 x i1> %686, %693, !dbg !66 + %695 = xor <2 x i32> %683, %685, !dbg !51 + %696 = select <2 x i1> %694, <2 x i32> %695, <2 x i32> zeroinitializer, !dbg !52 + %697 = xor <2 x i32> %696, %647, !dbg !53 + %698 = xor i32 %673, %679, !dbg !62 + %699 = xor i32 %675, %681, !dbg !62 + %700 = extractelement <2 x i1> %694, i64 0, !dbg !54 + %701 = select i1 %700, i32 %698, i32 0, !dbg !54 + %702 = extractelement <2 x i1> %694, i64 1, !dbg !54 + %703 = select i1 %702, i32 %699, i32 0, !dbg !54 + %704 = xor i32 %701, %654, !dbg !55 + %705 = xor i32 %703, %655, !dbg !55 + %706 = extractelement <2 x i32> %697, i64 0, !dbg !59 + %707 = mul nuw nsw i32 %706, %38, !dbg !56 + %708 = extractelement <2 x i32> %697, i64 1, !dbg !59 + %709 = mul nuw nsw i32 %708, %38, !dbg !56 + %710 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %707, i32 1, i32 31), !dbg !57 + %711 = add i32 %707, %710, !dbg !58 + %712 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %709, i32 1, i32 31), !dbg !57 + %713 = add i32 %709, %712, !dbg !58 + %714 = mul nuw nsw i32 %706, %35, !dbg !59 + %715 = mul nuw nsw i32 %708, %35, !dbg !59 + %716 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %714, i32 1, i32 31), !dbg !57 + %717 = add i32 %714, %716, !dbg !58 + %718 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %715, i32 1, i32 31), !dbg !57 + %719 = add i32 %715, %718, !dbg !58 + %720 = mul nuw nsw i32 %704, %38, !dbg !60 + %721 = mul nuw nsw i32 %705, %38, !dbg !60 + %722 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %720, i32 1, i32 31), !dbg !57 + %723 = add i32 %720, %722, !dbg !58 + %724 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %721, i32 1, i32 31), !dbg !57 + %725 = add i32 %721, %724, !dbg !58 + %726 = mul nuw nsw i32 %704, %35, !dbg !61 + %727 = mul nuw nsw i32 %705, %35, !dbg !61 + %728 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %726, i32 1, i32 31), 
!dbg !57 + %729 = add i32 %726, %728, !dbg !58 + %730 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %727, i32 1, i32 31), !dbg !57 + %731 = add i32 %727, %730, !dbg !58 + %732 = insertelement <2 x i32> poison, i32 %711, i64 0, !dbg !48 + %733 = insertelement <2 x i32> %732, i32 %713, i64 1, !dbg !48 + %734 = insertelement <2 x i32> poison, i32 %717, i64 0, !dbg !48 + %735 = insertelement <2 x i32> %734, i32 %719, i64 1, !dbg !48 + %736 = icmp slt <2 x i32> %733, %735, !dbg !48 + %737 = icmp eq <2 x i32> %733, %735, !dbg !63 + %738 = insertelement <2 x i32> poison, i32 %723, i64 0, !dbg !64 + %739 = insertelement <2 x i32> %738, i32 %725, i64 1, !dbg !64 + %740 = insertelement <2 x i32> poison, i32 %729, i64 0, !dbg !64 + %741 = insertelement <2 x i32> %740, i32 %731, i64 1, !dbg !64 + %742 = icmp sgt <2 x i32> %739, %741, !dbg !64 + %743 = and <2 x i1> %737, %742, !dbg !65 + %744 = or <2 x i1> %736, %743, !dbg !66 + %745 = xor <2 x i32> %733, %735, !dbg !51 + %746 = select <2 x i1> %744, <2 x i32> %745, <2 x i32> zeroinitializer, !dbg !52 + %747 = xor <2 x i32> %746, %697, !dbg !53 + %748 = xor i32 %723, %729, !dbg !62 + %749 = xor i32 %725, %731, !dbg !62 + %750 = extractelement <2 x i1> %744, i64 0, !dbg !54 + %751 = select i1 %750, i32 %748, i32 0, !dbg !54 + %752 = extractelement <2 x i1> %744, i64 1, !dbg !54 + %753 = select i1 %752, i32 %749, i32 0, !dbg !54 + %754 = xor i32 %751, %704, !dbg !55 + %755 = xor i32 %753, %705, !dbg !55 + %756 = extractelement <2 x i32> %747, i64 0, !dbg !48 + %757 = extractelement <2 x i32> %747, i64 1, !dbg !48 + %758 = icmp slt i32 %756, %757, !dbg !48 + %759 = icmp eq i32 %756, %757, !dbg !63 + %760 = icmp sgt i32 %754, %755, !dbg !64 + %761 = and i1 %759, %760, !dbg !65 + %762 = or i1 %758, %761, !dbg !66 + %763 = xor i32 %754, %755, !dbg !62 + %764 = select i1 %762, i32 %763, i32 0, !dbg !54 + %765 = xor i32 %764, %754, !dbg !55 + %766 = xor i32 %764, %755, !dbg !55 + %narrow = select i1 %19, i1 %30, i1 false, 
!dbg !67 + %767 = zext i1 %narrow to i64, !dbg !67 + %narrow18 = select i1 %19, i1 %32, i1 false, !dbg !67 + %768 = zext i1 %narrow18 to i64, !dbg !67 + %769 = add nuw nsw i64 %767, %768, !dbg !68 + %770 = trunc nuw nsw i64 %769 to i32, !dbg !70 + %771 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %770, i32 4, i32 31), !dbg !70 + %772 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 4, i32 31), !dbg !70 + %773 = insertelement <2 x i32> poison, i32 %771, i64 0, !dbg !70 + %774 = insertelement <2 x i32> %773, i32 %772, i64 1, !dbg !70 + %775 = bitcast <2 x i32> %774 to i64, !dbg !70 + %776 = add i64 %769, %775, !dbg !68 + %extelt.offset = lshr i64 %776, 32, !dbg !70 + %777 = trunc nuw i64 %extelt.offset to i32, !dbg !70 + %778 = trunc i64 %776 to i32, !dbg !70 + %779 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %778, i32 2, i32 31), !dbg !70 + %780 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %777, i32 2, i32 31), !dbg !70 + %781 = insertelement <2 x i32> poison, i32 %779, i64 0, !dbg !70 + %782 = insertelement <2 x i32> %781, i32 %780, i64 1, !dbg !70 + %783 = bitcast <2 x i32> %782 to i64, !dbg !70 + %784 = add i64 %776, %783, !dbg !68 + %extelt.offset19 = lshr i64 %784, 32, !dbg !70 + %785 = trunc nuw i64 %extelt.offset19 to i32, !dbg !70 + %786 = trunc i64 %784 to i32, !dbg !70 + %787 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %786, i32 1, i32 31), !dbg !70 + %788 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %785, i32 1, i32 31), !dbg !70 + %789 = insertelement <2 x i32> poison, i32 %787, i64 0, !dbg !70 + %790 = insertelement <2 x i32> %789, i32 %788, i64 1, !dbg !70 + %791 = bitcast <2 x i32> %790 to i64, !dbg !70 + %792 = add i64 %784, %791, !dbg !68 + %793 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %15, !dbg !71 + %794 = insertelement <1 x i64> poison, i64 %792, i64 0, !dbg !71 + store <1 x i64> %794, ptr addrspace(3) %793, align 8, !dbg !71 + tail call 
void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !71 + %795 = shl nuw nsw i32 %17, 3, !dbg !71 + %796 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %795, !dbg !71 + %797 = load i64, ptr addrspace(3) %796, align 8, !dbg !71 + %narrow20 = select i1 %19, i1 %404, i1 false, !dbg !72 + %798 = zext i1 %narrow20 to i64, !dbg !72 + %narrow21 = select i1 %19, i1 %405, i1 false, !dbg !72 + %799 = zext i1 %narrow21 to i64, !dbg !72 + %800 = add nuw nsw i64 %798, %799, !dbg !73 + %801 = trunc nuw nsw i64 %800 to i32, !dbg !75 + %802 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %801, i32 4, i32 31), !dbg !75 + %803 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 4, i32 31), !dbg !75 + %804 = insertelement <2 x i32> poison, i32 %802, i64 0, !dbg !75 + %805 = insertelement <2 x i32> %804, i32 %803, i64 1, !dbg !75 + %806 = bitcast <2 x i32> %805 to i64, !dbg !75 + %807 = add i64 %800, %806, !dbg !73 + %extelt.offset23 = lshr i64 %807, 32, !dbg !75 + %808 = trunc nuw i64 %extelt.offset23 to i32, !dbg !75 + %809 = trunc i64 %807 to i32, !dbg !75 + %810 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %809, i32 2, i32 31), !dbg !75 + %811 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %808, i32 2, i32 31), !dbg !75 + %812 = insertelement <2 x i32> poison, i32 %810, i64 0, !dbg !75 + %813 = insertelement <2 x i32> %812, i32 %811, i64 1, !dbg !75 + %814 = bitcast <2 x i32> %813 to i64, !dbg !75 + %815 = add i64 %807, %814, !dbg !73 + %extelt.offset24 = lshr i64 %815, 32, !dbg !75 + %816 = trunc nuw i64 %extelt.offset24 to i32, !dbg !75 + %817 = trunc i64 %815 to i32, !dbg !75 + %818 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %817, i32 1, i32 31), !dbg !75 + %819 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %816, i32 1, i32 31), !dbg !75 + %820 = insertelement <2 x i32> poison, i32 %818, i64 0, !dbg !75 + %821 = insertelement <2 x i32> %820, i32 %819, i64 1, !dbg !75 + %822 = 
bitcast <2 x i32> %821 to i64, !dbg !75 + %823 = add i64 %815, %822, !dbg !73 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !76 + %824 = insertelement <1 x i64> poison, i64 %823, i64 0, !dbg !76 + store <1 x i64> %824, ptr addrspace(3) %793, align 8, !dbg !76 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !76 + %825 = load i64, ptr addrspace(3) %796, align 8, !dbg !76 + %826 = trunc i64 %792 to i32, !dbg !71 + %827 = icmp slt i32 %20, %826, !dbg !77 + %828 = icmp slt i32 %21, %826, !dbg !77 + %829 = select i1 %827, i32 %402, i32 16, !dbg !78 + %830 = select i1 %828, i32 %403, i32 16, !dbg !78 + %831 = add i32 %829, 17, !dbg !79 + %832 = add i32 %830, 17, !dbg !79 + %833 = icmp slt i32 %829, 0, !dbg !80 + %834 = icmp slt i32 %830, 0, !dbg !80 + %835 = select i1 %833, i32 %831, i32 %829, !dbg !81 + %836 = select i1 %834, i32 %832, i32 %830, !dbg !81 + %837 = icmp ugt i32 %835, 16, !dbg !82 + %838 = icmp ugt i32 %836, 16, !dbg !82 + %.not2629 = or i1 %837, %838, !dbg !83 + %839 = and i1 %19, %.not2629, !dbg !83 + br i1 %839, label %840, label %841, !dbg !83 + +840: ; preds = %11 + tail call void @__assertfail(ptr nonnull @assertMessage_0, ptr nonnull @assertFile_0, i32 71, ptr nonnull @assertFunc_0, i64 1), !dbg !83 + unreachable, !dbg !83 + +841: ; preds = %11 + %842 = trunc i64 %823 to i32, !dbg !76 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !83 + %843 = icmp slt i32 %20, %842, !dbg !84 + %844 = icmp slt i32 %21, %842, !dbg !84 + %845 = select i1 %843, i32 %765, i32 16, !dbg !85 + %846 = select i1 %844, i32 %766, i32 16, !dbg !85 + %847 = add i32 %845, 17, !dbg !86 + %848 = add i32 %846, 17, !dbg !86 + %849 = icmp slt i32 %845, 0, !dbg !87 + %850 = icmp slt i32 %846, 0, !dbg !87 + %851 = select i1 %849, i32 %847, i32 %845, !dbg !88 + %852 = select i1 %850, i32 %848, i32 %846, !dbg !88 + %853 = icmp ugt i32 %851, 16, !dbg !89 + %854 = icmp ugt i32 %852, 16, !dbg !89 + %.not3134 = or i1 %853, 
%854, !dbg !90 + %855 = and i1 %19, %.not3134, !dbg !90 + br i1 %855, label %856, label %857, !dbg !90 + +856: ; preds = %841 + tail call void @__assertfail(ptr nonnull @assertMessage_1, ptr nonnull @assertFile_1, i32 80, ptr nonnull @assertFunc_1, i64 1), !dbg !90 + unreachable, !dbg !90 + +857: ; preds = %841 + %858 = trunc i64 %825 to i32, !dbg !76 + %859 = trunc i64 %797 to i32, !dbg !71 + %860 = or disjoint i32 %13, %17, !dbg !13 + %861 = icmp slt i32 %860, 32, !dbg !14 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !90 + %862 = sext i32 %860 to i64, !dbg !91 + %863 = getelementptr i32, ptr addrspace(1) %1, i64 %862, !dbg !91 + %864 = icmp eq i32 %15, 0, !dbg !92 + %865 = and i1 %864, %861, !dbg !92 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %859, ptr addrspace(1) %863, i1 %865) #5, !dbg !92 + %866 = getelementptr i32, ptr addrspace(1) %2, i64 %862, !dbg !93 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %858, ptr addrspace(1) %866, i1 %865) #5, !dbg !94 + %867 = getelementptr i32, ptr addrspace(1) %3, i64 %24, !dbg !95 + tail call void asm sideeffect "@$3 st.global.v2.b32 [ $2 + 0 ], { $0, $1 };", "r,r,l,b"(i32 %402, i32 %403, ptr addrspace(1) %867, i1 %19) #5, !dbg !96 + %868 = mul i32 %18, 17, !dbg !97 + %869 = add i32 %835, %868, !dbg !98 + %870 = add i32 %836, %868, !dbg !98 + %871 = sext i32 %869 to i64, !dbg !99 + %872 = getelementptr i32, ptr addrspace(1) %4, i64 %871, !dbg !99 + %873 = sext i32 %870 to i64, !dbg !99 + %874 = getelementptr i32, ptr addrspace(1) %4, i64 %873, !dbg !99 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !100 + %875 = ptrtoint ptr addrspace(1) %872 to i64, !dbg !100 + %876 = ptrtoint ptr addrspace(1) %874 to i64, !dbg !100 + %877 = shl nuw nsw i32 %14, 5, !dbg !100 + %878 = and i32 %877, 96, !dbg !100 + %879 = and i32 %14, 48, !dbg !100 + %880 = shl nuw nsw i32 %879, 3, !dbg !100 + %881 = shl nuw nsw 
i32 %14, 1, !dbg !100 + %882 = and i32 %881, 120, !dbg !100 + %883 = or disjoint i32 %878, %880, !dbg !100 + %884 = xor i32 %883, %882, !dbg !100 + %885 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %884, !dbg !100 + %886 = insertelement <1 x i64> poison, i64 %875, i64 0, !dbg !100 + store <1 x i64> %886, ptr addrspace(3) %885, align 8, !dbg !100 + %887 = xor i32 %884, 520, !dbg !100 + %888 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %887, !dbg !100 + %889 = insertelement <1 x i64> poison, i64 %876, i64 0, !dbg !100 + store <1 x i64> %889, ptr addrspace(3) %888, align 8, !dbg !100 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !100 + %890 = shl nuw nsw i32 %14, 6, !dbg !100 + %891 = and i32 %890, 384, !dbg !100 + %892 = shl nuw nsw i32 %17, 4, !dbg !100 + %893 = shl nuw nsw i32 %879, 1, !dbg !100 + %894 = and i32 %14, 8, !dbg !100 + %895 = icmp eq i32 %894, 0, !dbg !100 + %896 = select i1 %895, i32 0, i32 520, !dbg !100 + %897 = or disjoint i32 %891, %892, !dbg !100 + %898 = xor i32 %897, %893, !dbg !100 + %899 = or disjoint i32 %898, %896, !dbg !100 + %900 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %899, !dbg !100 + %901 = load i64, ptr addrspace(3) %900, align 8, !dbg !100 + %902 = xor i32 %899, 8, !dbg !100 + %903 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %902, !dbg !100 + %904 = load i64, ptr addrspace(3) %903, align 8, !dbg !100 + %905 = inttoptr i64 %901 to ptr addrspace(1), !dbg !100 + %906 = inttoptr i64 %904 to ptr addrspace(1), !dbg !100 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %905, i1 %861) #5, !dbg !100 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %906, i1 %861) #5, !dbg !100 + %907 = getelementptr i32, ptr addrspace(1) %5, i64 %24, !dbg !101 + tail call void asm sideeffect "@$3 st.global.v2.b32 [ $2 + 0 ], { $0, 
$1 };", "r,r,l,b"(i32 %765, i32 %766, ptr addrspace(1) %907, i1 %19) #5, !dbg !102 + %908 = add i32 %851, %868, !dbg !103 + %909 = add i32 %852, %868, !dbg !103 + %910 = sext i32 %908 to i64, !dbg !104 + %911 = getelementptr i32, ptr addrspace(1) %6, i64 %910, !dbg !104 + %912 = sext i32 %909 to i64, !dbg !104 + %913 = getelementptr i32, ptr addrspace(1) %6, i64 %912, !dbg !104 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !105 + %914 = ptrtoint ptr addrspace(1) %911 to i64, !dbg !105 + %915 = ptrtoint ptr addrspace(1) %913 to i64, !dbg !105 + %916 = insertelement <1 x i64> poison, i64 %914, i64 0, !dbg !105 + store <1 x i64> %916, ptr addrspace(3) %885, align 8, !dbg !105 + %917 = insertelement <1 x i64> poison, i64 %915, i64 0, !dbg !105 + store <1 x i64> %917, ptr addrspace(3) %888, align 8, !dbg !105 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !105 + %918 = load i64, ptr addrspace(3) %900, align 8, !dbg !105 + %919 = load i64, ptr addrspace(3) %903, align 8, !dbg !105 + %920 = inttoptr i64 %918 to ptr addrspace(1), !dbg !105 + %921 = inttoptr i64 %919 to ptr addrspace(1), !dbg !105 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %920, i1 %861) #5, !dbg !105 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %921, i1 %861) #5, !dbg !105 + ret void, !dbg !106 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #2 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #2 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #3 + +; Function Attrs: 
convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #4 + +attributes #0 = { noreturn } +attributes #1 = { "nvvm.reqntid"="64" } +attributes #2 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #3 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #4 = { convergent nocallback nounwind } +attributes #5 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = !DISubprogram(name: "__assertfail", linkageName: "__assertfail", scope: !6, file: !6, type: !7, spFlags: DISPFlagOptimized) +!6 = !DIFile(filename: "", directory: "") +!7 = !DISubroutineType(cc: DW_CC_normal, types: !8) +!8 = !{} +!9 = distinct !DISubprogram(name: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", linkageName: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", scope: !1, file: !1, line: 18, type: !7, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!10 = !DILocation(line: 24, column: 28, scope: !9) +!11 = !DILocation(line: 24, column: 33, scope: !9) +!12 = !DILocation(line: 25, column: 44, scope: !9) +!13 = !DILocation(line: 25, column: 23, scope: !9) +!14 = !DILocation(line: 26, column: 21, scope: !9) +!15 = !DILocation(line: 27, column: 38, scope: !9) +!16 = !DILocation(line: 34, column: 40, scope: !9) +!17 = 
!DILocation(line: 34, column: 37, scope: !9) +!18 = !DILocation(line: 34, column: 30, scope: !9) +!19 = !DILocation(line: 34, column: 45, scope: !9) +!20 = !DILocation(line: 39, column: 18, scope: !9) +!21 = !DILocation(line: 0, scope: !9) +!22 = !DILocation(line: 627, column: 44, scope: !23, inlinedAt: !25) +!23 = distinct !DILexicalBlockFile(scope: !9, file: !24, discriminator: 0) +!24 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!25 = !DILocation(line: 46, column: 71, scope: !9) +!26 = !DILocation(line: 537, column: 21, scope: !23, inlinedAt: !25) +!27 = !DILocation(line: 574, column: 22, scope: !23, inlinedAt: !25) +!28 = !DILocation(line: 599, column: 28, scope: !23, inlinedAt: !25) +!29 = !DILocation(line: 600, column: 38, scope: !23, inlinedAt: !25) +!30 = !DILocation(line: 600, column: 46, scope: !23, inlinedAt: !25) +!31 = !DILocation(line: 600, column: 15, scope: !23, inlinedAt: !25) +!32 = !DILocation(line: 601, column: 48, scope: !23, inlinedAt: !25) +!33 = !DILocation(line: 601, column: 59, scope: !23, inlinedAt: !25) +!34 = !DILocation(line: 601, column: 22, scope: !23, inlinedAt: !25) +!35 = !DILocation(line: 538, column: 40, scope: !23, inlinedAt: !25) +!36 = !DILocation(line: 291, column: 36, scope: !37, inlinedAt: !25) +!37 = distinct !DILexicalBlockFile(scope: !9, file: !38, discriminator: 0) +!38 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!39 = !DILocation(line: 261, column: 15, scope: !37, inlinedAt: !25) +!40 = !DILocation(line: 539, column: 41, scope: !23, inlinedAt: !25) +!41 = !DILocation(line: 548, column: 23, scope: !23, inlinedAt: !25) +!42 = !DILocation(line: 551, column: 23, scope: !23, inlinedAt: !25) +!43 = !DILocation(line: 591, column: 21, scope: !23, inlinedAt: !25) +!44 = !DILocation(line: 594, column: 40, scope: !23, inlinedAt: !25) +!45 = !DILocation(line: 594, column: 29, 
scope: !23, inlinedAt: !25) +!46 = !DILocation(line: 594, column: 23, scope: !23, inlinedAt: !25) +!47 = !DILocation(line: 47, column: 20, scope: !9) +!48 = !DILocation(line: 574, column: 22, scope: !23, inlinedAt: !49) +!49 = !DILocation(line: 51, column: 71, scope: !9) +!50 = !DILocation(line: 599, column: 28, scope: !23, inlinedAt: !49) +!51 = !DILocation(line: 600, column: 38, scope: !23, inlinedAt: !49) +!52 = !DILocation(line: 600, column: 46, scope: !23, inlinedAt: !49) +!53 = !DILocation(line: 600, column: 15, scope: !23, inlinedAt: !49) +!54 = !DILocation(line: 601, column: 59, scope: !23, inlinedAt: !49) +!55 = !DILocation(line: 601, column: 22, scope: !23, inlinedAt: !49) +!56 = !DILocation(line: 538, column: 40, scope: !23, inlinedAt: !49) +!57 = !DILocation(line: 291, column: 36, scope: !37, inlinedAt: !49) +!58 = !DILocation(line: 261, column: 15, scope: !37, inlinedAt: !49) +!59 = !DILocation(line: 539, column: 41, scope: !23, inlinedAt: !49) +!60 = !DILocation(line: 548, column: 23, scope: !23, inlinedAt: !49) +!61 = !DILocation(line: 551, column: 23, scope: !23, inlinedAt: !49) +!62 = !DILocation(line: 601, column: 48, scope: !23, inlinedAt: !49) +!63 = !DILocation(line: 591, column: 21, scope: !23, inlinedAt: !49) +!64 = !DILocation(line: 594, column: 40, scope: !23, inlinedAt: !49) +!65 = !DILocation(line: 594, column: 29, scope: !23, inlinedAt: !49) +!66 = !DILocation(line: 594, column: 23, scope: !23, inlinedAt: !49) +!67 = !DILocation(line: 54, column: 35, scope: !9) +!68 = !DILocation(line: 261, column: 15, scope: !37, inlinedAt: !69) +!69 = !DILocation(line: 55, column: 26, scope: !9) +!70 = !DILocation(line: 291, column: 36, scope: !37, inlinedAt: !69) +!71 = !DILocation(line: 60, column: 21, scope: !9) +!72 = !DILocation(line: 58, column: 35, scope: !9) +!73 = !DILocation(line: 261, column: 15, scope: !37, inlinedAt: !74) +!74 = !DILocation(line: 59, column: 26, scope: !9) +!75 = !DILocation(line: 291, column: 36, scope: !37, inlinedAt: 
!74) +!76 = !DILocation(line: 61, column: 21, scope: !9) +!77 = !DILocation(line: 64, column: 19, scope: !9) +!78 = !DILocation(line: 66, column: 35, scope: !9) +!79 = !DILocation(line: 68, column: 20, scope: !9) +!80 = !DILocation(line: 69, column: 20, scope: !9) +!81 = !DILocation(line: 70, column: 35, scope: !9) +!82 = !DILocation(line: 71, column: 38, scope: !9) +!83 = !DILocation(line: 71, column: 63, scope: !9) +!84 = !DILocation(line: 75, column: 19, scope: !9) +!85 = !DILocation(line: 76, column: 35, scope: !9) +!86 = !DILocation(line: 77, column: 20, scope: !9) +!87 = !DILocation(line: 78, column: 20, scope: !9) +!88 = !DILocation(line: 79, column: 35, scope: !9) +!89 = !DILocation(line: 80, column: 38, scope: !9) +!90 = !DILocation(line: 80, column: 63, scope: !9) +!91 = !DILocation(line: 81, column: 25, scope: !9) +!92 = !DILocation(line: 81, column: 37, scope: !9) +!93 = !DILocation(line: 82, column: 25, scope: !9) +!94 = !DILocation(line: 82, column: 37, scope: !9) +!95 = !DILocation(line: 83, column: 25, scope: !9) +!96 = !DILocation(line: 83, column: 47, scope: !9) +!97 = !DILocation(line: 84, column: 52, scope: !9) +!98 = !DILocation(line: 84, column: 49, scope: !9) +!99 = !DILocation(line: 84, column: 25, scope: !9) +!100 = !DILocation(line: 84, column: 85, scope: !9) +!101 = !DILocation(line: 85, column: 25, scope: !9) +!102 = !DILocation(line: 85, column: 47, scope: !9) +!103 = !DILocation(line: 86, column: 49, scope: !9) +!104 = !DILocation(line: 86, column: 25, scope: !9) +!105 = !DILocation(line: 86, column: 85, scope: !9) +!106 = !DILocation(line: 86, column: 4, scope: !9) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx 
b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx new file mode 100644 index 0000000000000000000000000000000000000000..df80bb3c82744e470476a6c9d9685420e69a1652 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx @@ -0,0 +1,1774 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 // -- Begin function triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +) +.noreturn; +.global .align 1 .b8 assertFunc_1[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_1[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 98, 115, 47, 99, 98, 115, 54, 53, 50, 105, 55, 99, 116, 53, 55, 117, 103, 120, 54, 118, 110, 99, 51, 53, 110, 54, 51, 116, 105, 110, 112, 122, 117, 54, 98, 97, 51, 122, 121, 109, 117, 102, 111, 100, 105, 103, 52, 104, 112, 122, 97, 114, 119, 99, 108, 46, 112, 121}; +.global .align 1 .b8 assertMessage_1[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 
116, 109, 112, 52, 57, 32, 60, 32, 49, 55}; +.global .align 1 .b8 assertFunc_0[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_0[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 98, 115, 47, 99, 98, 115, 54, 53, 50, 105, 55, 99, 116, 53, 55, 117, 103, 120, 54, 118, 110, 99, 51, 53, 110, 54, 51, 116, 105, 110, 112, 122, 117, 54, 98, 97, 51, 122, 121, 109, 117, 102, 111, 100, 105, 103, 52, 104, 112, 122, 97, 114, 119, 99, 108, 46, 112, 121}; +.global .align 1 .b8 assertMessage_0[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 52, 48, 32, 60, 32, 49, 55}; +.extern .shared .align 16 .b8 global_smem[]; + // @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.visible .entry triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2( + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3, + .param .u64 .ptr .global .align 1 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6, + .param .u32 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_7, + .param .u32 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_8, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_9, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_10 +) +.reqntid 64 +{ + .reg .pred %p<204>; + .reg .b32 %r<600>; + .reg .b64 %rd<78>; + .loc 1 18 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:18:0 + +// %bb.0: + ld.param.b64 %rd14, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0]; +$L__tmp0: + .loc 1 24 28 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:24:28 + mov.u32 %r16, %ctaid.x; + .loc 1 24 33 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:24:33 + shl.b32 %r1, %r16, 3; + .loc 1 25 44 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:25:44 + mov.u32 %r2, %tid.x; + and.b32 %r3, %r2, 56; + bfe.u32 %r17, %r2, 3, 3; + and.b32 %r4, %r2, 7; + .loc 1 25 23 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:25:23 + or.b32 %r5, %r17, 
%r1; + .loc 1 26 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:26:21 + setp.gt.s32 %p2, %r5, 31; + setp.lt.s32 %p1, %r5, 32; + .loc 1 27 38 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:27:38 + shl.b32 %r6, %r4, 1; + or.b32 %r7, %r6, 1; + .loc 1 34 40 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:40 + shl.b32 %r18, %r5, 4; + .loc 1 34 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:37 + or.b32 %r19, %r18, %r6; + .loc 1 34 30 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:30 + mad.wide.s32 %rd13, %r19, 8, %rd14; + .loc 1 34 45 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:45 + // begin inline asm + mov.u64 %rd11, 0x0; + mov.u64 %rd12, 0x0; + @%p1 ld.global.v2.b64 { %rd11, %rd12 }, [ %rd13 + 0 ]; + // end inline asm + .loc 1 39 18 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:39:18 + add.s64 %rd15, %rd11, -1; + setp.gt.u64 %p3, %rd15, 16382; + setp.lt.u64 %p4, %rd15, 16383; + add.s64 %rd16, %rd12, -1; + setp.lt.u64 %p5, %rd16, 16383; + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + selp.b32 %r20, 1, 0, %p4; + selp.b32 %r21, 1, 0, %p5; +$L__tmp1: + .loc 2 627 44 // triton_helpers.py:627:44 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r22, %r2, 1; + shr.u32 %r23, %r2, 1; + bfe.u32 %r24, %r2, 1, 1; + shr.u32 %r25, %r2, 2; + bfe.u32 %r26, %r2, 2, 1; + .loc 2 537 21 // triton_helpers.py:537:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r27, %r22, 1; + xor.b32 %r28, %r24, 1; + xor.b32 %r29, %r26, 1; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p6, %p5, %p3; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ne.b32 %p7, %r22, 0; + xor.pred %p8, %p6, %p7; + .loc 2 600 38 // triton_helpers.py:600:38 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r30, %r20, %r21; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r31, %r30, 0, %p8; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r32, %r31, %r20; + xor.b32 %r33, %r31, %r21; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r34, %r7, %r6; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r35, %r34, 0, %p8; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r36, %r35, %r6; + xor.b32 %r37, %r35, %r7; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r38, %r32, %r27; + mul.lo.s32 %r39, %r33, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r40, %r38, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r41, %r38, %r40; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r42, %r39, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r43, %r39, %r42; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r44, %r32, %r22; + mul.lo.s32 %r45, %r33, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r46, %r44, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r47, %r44, %r46; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r48, %r45, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r49, %r45, %r48; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r50, %r36, %r27; + mul.lo.s32 %r51, %r37, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r52, %r50, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r53, %r50, %r52; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r54, %r51, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r55, %r51, %r54; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r56, %r36, %r22; + mul.lo.s32 %r57, %r37, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r58, %r56, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r59, %r56, %r58; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r60, %r57, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r61, %r57, %r60; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r62, %r23, 1; + setp.ne.b32 
%p9, %r62, 0; + setp.ge.s32 %p10, %r41, %r47; + setp.ge.s32 %p11, %r43, %r49; + setp.ne.b32 %p12, %r43, %r49; + setp.ne.b32 %p13, %r41, %r47; + setp.le.s32 %p14, %r55, %r61; + setp.le.s32 %p15, %r53, %r59; + or.pred %p16, %p13, %p15; + or.pred %p17, %p12, %p14; + and.pred %p18, %p11, %p17; + and.pred %p19, %p10, %p16; + xor.pred %p20, %p19, %p9; + xor.pred %p21, %p18, %p9; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r63, %r41, %r47; + xor.b32 %r64, %r43, %r49; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r65, 0, %r64, %p21; + selp.b32 %r66, 0, %r63, %p20; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r67, %r66, %r32; + xor.b32 %r68, %r65, %r33; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r69, %r53, %r59; + xor.b32 %r70, %r55, %r61; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r71, 0, %r70, %p21; + selp.b32 %r72, 0, %r69, %p20; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r73, %r72, %r36; + xor.b32 %r74, %r71, %r37; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p22, %r67, %r68; + setp.ne.b32 %p23, %r67, %r68; + setp.le.s32 %p24, %r73, %r74; + or.pred %p25, %p23, %p24; + and.pred %p26, %p22, %p25; + xor.pred %p27, %p26, %p9; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r75, %r67, %r68; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r76, 0, %r75, %p27; + .loc 2 601 48 // 
triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r77, %r73, %r74; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r78, 0, %r77, %p27; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r79, %r78, %r73; + xor.b32 %r80, %r78, %r74; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r81, %r79, %r28; + mul.lo.s32 %r82, %r80, %r28; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r83, %r79, %r24; + mul.lo.s32 %r84, %r80, %r24; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r85, %r25, 1; + setp.ne.b32 %p28, %r85, 0; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r86, %r76, %r68; + xor.b32 %r87, %r76, %r67; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r88, %r87, %r28; + mul.lo.s32 %r89, %r86, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r90, %r88, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r91, %r88, %r90; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r92, %r89, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r93, %r89, %r92; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r94, %r87, %r24; + mul.lo.s32 
%r95, %r86, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r96, %r94, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r97, %r94, %r96; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r98, %r95, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r99, %r95, %r98; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r100, %r81, 2, 31, -1; + shfl.sync.bfly.b32 %r101, %r82, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r102, %r82, %r101; + add.s32 %r103, %r81, %r100; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r104, %r83, 2, 31, -1; + shfl.sync.bfly.b32 %r105, %r84, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r106, %r84, %r105; + add.s32 %r107, %r83, %r104; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p29, %r93, %r99; + setp.ge.s32 %p30, %r91, %r97; + setp.ne.b32 %p31, %r91, %r97; + setp.ne.b32 %p32, %r93, %r99; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r108, %r93, %r99; + xor.b32 %r109, %r91, %r97; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.le.s32 %p33, %r103, %r107; + setp.le.s32 %p34, %r102, %r106; + or.pred %p35, %p32, %p34; + or.pred %p36, %p31, %p33; + and.pred %p37, %p30, %p36; + and.pred %p38, %p29, 
%p35; + xor.pred %p39, %p38, %p28; + xor.pred %p40, %p37, %p28; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r110, 0, %r109, %p40; + selp.b32 %r111, 0, %r108, %p39; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r112, %r111, %r86; + xor.b32 %r113, %r110, %r87; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r114, %r102, %r106; + xor.b32 %r115, %r103, %r107; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r116, 0, %r115, %p40; + selp.b32 %r117, 0, %r114, %p39; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r118, %r117, %r80; + xor.b32 %r119, %r116, %r79; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r120, %r113, %r27; + mul.lo.s32 %r121, %r112, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r122, %r120, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r123, %r120, %r122; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r124, %r121, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r125, %r121, %r124; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r126, %r113, %r22; + mul.lo.s32 %r127, %r112, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r128, 
%r126, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r129, %r126, %r128; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r130, %r127, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r131, %r127, %r130; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r132, %r119, %r27; + mul.lo.s32 %r133, %r118, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r134, %r132, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r135, %r132, %r134; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r136, %r133, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r137, %r133, %r136; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r138, %r119, %r22; + mul.lo.s32 %r139, %r118, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r140, %r138, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r141, %r138, %r140; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r142, %r139, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r143, %r139, %r142; + .loc 2 599 28 // triton_helpers.py:599:28 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p41, %r123, %r129; + setp.ge.s32 %p42, %r125, %r131; + setp.ne.b32 %p43, %r125, %r131; + setp.ne.b32 %p44, %r123, %r129; + setp.le.s32 %p45, %r137, %r143; + setp.le.s32 %p46, %r135, %r141; + or.pred %p47, %p44, %p46; + or.pred %p48, %p43, %p45; + and.pred %p49, %p42, %p48; + and.pred %p50, %p41, %p47; + xor.pred %p51, %p50, %p28; + xor.pred %p52, %p49, %p28; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r144, %r123, %r129; + xor.b32 %r145, %r125, %r131; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r146, 0, %r145, %p52; + selp.b32 %r147, 0, %r144, %p51; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r148, %r147, %r113; + xor.b32 %r149, %r146, %r112; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r150, %r135, %r141; + xor.b32 %r151, %r137, %r143; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r152, 0, %r151, %p52; + selp.b32 %r153, 0, %r150, %p51; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r154, %r153, %r119; + xor.b32 %r155, %r152, %r118; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p53, %r148, %r149; + setp.ne.b32 %p54, %r148, %r149; + setp.le.s32 %p55, %r154, %r155; + or.pred %p56, %p54, %p55; + and.pred %p57, %p53, %p56; + xor.pred %p58, %p57, %p28; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r156, %r148, %r149; + .loc 2 600 46 // triton_helpers.py:600:46 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r157, 0, %r156, %p58; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r158, %r157, %r149; + xor.b32 %r159, %r157, %r148; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r160, %r154, %r155; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r161, 0, %r160, %p58; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r162, %r161, %r155; + xor.b32 %r163, %r161, %r154; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r164, %r159, %r29; + mul.lo.s32 %r165, %r158, %r29; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r166, %r164, 4, 31, -1; + shfl.sync.bfly.b32 %r167, %r165, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r168, %r164, %r166; + add.s32 %r169, %r165, %r167; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r170, %r159, %r26; + mul.lo.s32 %r171, %r158, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r172, %r170, 4, 31, -1; + shfl.sync.bfly.b32 %r173, %r171, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r174, %r170, %r172; + add.s32 %r175, %r171, %r173; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r176, %r163, %r29; + mul.lo.s32 %r177, %r162, %r29; + .loc 3 
291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r178, %r176, 4, 31, -1; + shfl.sync.bfly.b32 %r179, %r177, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r180, %r177, %r179; + add.s32 %r181, %r176, %r178; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r182, %r163, %r26; + mul.lo.s32 %r183, %r162, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r184, %r182, 4, 31, -1; + shfl.sync.bfly.b32 %r185, %r183, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r186, %r183, %r185; + add.s32 %r187, %r182, %r184; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p59, %r169, %r175; + setp.lt.s32 %p60, %r168, %r174; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p61, %r168, %r174; + setp.eq.b32 %p62, %r169, %r175; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p63, %r181, %r187; + setp.gt.s32 %p64, %r180, %r186; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p65, %p62, %p64; + and.pred %p66, %p61, %p63; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p67, %p60, %p66; + or.pred %p68, %p59, %p65; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r188, %r168, %r174; + xor.b32 %r189, %r169, %r175; + .loc 2 600 46 // triton_helpers.py:600:46 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r190, %r189, 0, %p68; + selp.b32 %r191, %r188, 0, %p67; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r192, %r191, %r159; + xor.b32 %r193, %r190, %r158; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r194, %r180, %r186; + xor.b32 %r195, %r181, %r187; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r196, %r195, 0, %p67; + selp.b32 %r197, %r194, 0, %p68; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r198, %r197, %r162; + xor.b32 %r199, %r196, %r163; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r200, %r193, %r28; + mul.lo.s32 %r201, %r192, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r202, %r201, 2, 31, -1; + shfl.sync.bfly.b32 %r203, %r200, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r204, %r201, %r202; + add.s32 %r205, %r200, %r203; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r206, %r193, %r24; + mul.lo.s32 %r207, %r192, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r208, %r207, 2, 31, -1; + shfl.sync.bfly.b32 %r209, %r206, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r210, %r207, %r208; + add.s32 %r211, %r206, %r209; + .loc 2 548 23 // triton_helpers.py:548:23 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r212, %r199, %r28; + mul.lo.s32 %r213, %r198, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r214, %r212, 2, 31, -1; + shfl.sync.bfly.b32 %r215, %r213, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r216, %r213, %r215; + add.s32 %r217, %r212, %r214; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r218, %r199, %r24; + mul.lo.s32 %r219, %r198, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r220, %r218, 2, 31, -1; + shfl.sync.bfly.b32 %r221, %r219, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r222, %r219, %r221; + add.s32 %r223, %r218, %r220; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p69, %r205, %r211; + setp.lt.s32 %p70, %r204, %r210; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p71, %r204, %r210; + setp.eq.b32 %p72, %r205, %r211; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p73, %r217, %r223; + setp.gt.s32 %p74, %r216, %r222; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p75, %p72, %p74; + and.pred %p76, %p71, %p73; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p77, %p70, %p76; + or.pred %p78, %p69, %p75; + .loc 2 600 38 // triton_helpers.py:600:38 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r224, %r204, %r210; + xor.b32 %r225, %r205, %r211; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r226, %r225, 0, %p78; + selp.b32 %r227, %r224, 0, %p77; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r228, %r227, %r192; + xor.b32 %r229, %r226, %r193; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r230, %r217, %r223; + xor.b32 %r231, %r216, %r222; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r232, %r231, 0, %p78; + selp.b32 %r233, %r230, 0, %p77; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r234, %r233, %r199; + xor.b32 %r235, %r232, %r198; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r236, %r229, %r27; + mul.lo.s32 %r237, %r228, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r238, %r237, 1, 31, -1; + shfl.sync.bfly.b32 %r239, %r236, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r240, %r237, %r238; + add.s32 %r241, %r236, %r239; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r242, %r229, %r22; + mul.lo.s32 %r243, %r228, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r244, %r243, 1, 31, -1; + shfl.sync.bfly.b32 %r245, %r242, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r246, %r243, %r244; + add.s32 %r247, %r242, %r245; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r248, %r235, %r27; + mul.lo.s32 %r249, %r234, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r250, %r249, 1, 31, -1; + shfl.sync.bfly.b32 %r251, %r248, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r252, %r248, %r251; + add.s32 %r253, %r249, %r250; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r254, %r235, %r22; + mul.lo.s32 %r255, %r234, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r256, %r255, 1, 31, -1; + shfl.sync.bfly.b32 %r257, %r254, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r258, %r254, %r257; + add.s32 %r259, %r255, %r256; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p79, %r241, %r247; + setp.lt.s32 %p80, %r240, %r246; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p81, %r240, %r246; + setp.eq.b32 %p82, %r241, %r247; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p83, %r253, %r259; + setp.gt.s32 %p84, %r252, %r258; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p85, %p82, %p84; + and.pred %p86, %p81, %p83; + .loc 2 594 23 // triton_helpers.py:594:23 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p87, %p80, %p86; + or.pred %p88, %p79, %p85; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r260, %r240, %r246; + xor.b32 %r261, %r241, %r247; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r262, %r261, 0, %p88; + selp.b32 %r263, %r260, 0, %p87; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r264, %r263, %r228; + xor.b32 %r265, %r262, %r229; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r266, %r253, %r259; + xor.b32 %r267, %r252, %r258; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r268, %r267, 0, %p88; + selp.b32 %r269, %r266, 0, %p87; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r270, %r269, %r234; + xor.b32 %r271, %r268, %r235; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p89, %r264, %r265; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p90, %r264, %r265; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p91, %r270, %r271; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r272, %r270, %r271; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r273, %r272, 0, %p91; + selp.b32 %r274, %r273, 0, %p90; + selp.b32 %r275, %r272, %r274, %p89; + .loc 2 601 22 // 
triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r8, %r275, %r270; + xor.b32 %r9, %r275, %r271; +$L__tmp2: + .loc 1 47 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:47:20 + setp.ne.b64 %p92, %rd11, 16384; + setp.eq.b64 %p93, %rd11, 16384; + setp.eq.b64 %p94, %rd12, 16384; + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + selp.b32 %r276, 1, 0, %p93; + selp.b32 %r277, 1, 0, %p94; +$L__tmp3: + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p95, %p94, %p92; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.pred %p96, %p95, %p7; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r278, %r276, %r277; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r279, %r278, 0, %p96; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r280, %r279, %r276; + xor.b32 %r281, %r279, %r277; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r282, %r34, 0, %p96; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r283, %r282, %r6; + xor.b32 %r284, %r282, %r7; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r285, %r280, %r27; + mul.lo.s32 %r286, %r281, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r287, %r285, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r288, %r287, 
%r285; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r289, %r286, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r290, %r289, %r286; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r291, %r280, %r22; + mul.lo.s32 %r292, %r281, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r293, %r291, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r294, %r293, %r291; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r295, %r292, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r296, %r295, %r292; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r297, %r283, %r27; + mul.lo.s32 %r298, %r284, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r299, %r297, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r300, %r299, %r297; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r301, %r298, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r302, %r301, %r298; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r303, %r283, %r22; + mul.lo.s32 %r304, %r284, %r22; + .loc 3 291 36 // 
standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r305, %r303, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r306, %r305, %r303; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r307, %r304, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r308, %r307, %r304; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p97, %r288, %r294; + setp.ge.s32 %p98, %r290, %r296; + setp.ne.b32 %p99, %r290, %r296; + setp.ne.b32 %p100, %r288, %r294; + setp.le.s32 %p101, %r302, %r308; + setp.le.s32 %p102, %r300, %r306; + or.pred %p103, %p100, %p102; + or.pred %p104, %p99, %p101; + and.pred %p105, %p98, %p104; + and.pred %p106, %p97, %p103; + xor.pred %p107, %p106, %p9; + xor.pred %p108, %p105, %p9; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r309, %r296, %r290; + xor.b32 %r310, %r306, %r300; + xor.b32 %r311, %r294, %r288; + xor.b32 %r312, %r308, %r302; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r313, 0, %r309, %p108; + selp.b32 %r314, 0, %r310, %p107; + selp.b32 %r315, 0, %r311, %p107; + selp.b32 %r316, 0, %r312, %p108; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r317, %r314, %r283; + xor.b32 %r318, %r313, %r281; + xor.b32 %r319, %r316, %r284; + xor.b32 %r320, %r315, %r280; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p109, %r320, %r318; + setp.ne.b32 %p110, %r318, %r320; + setp.le.s32 %p111, %r317, %r319; + 
or.pred %p112, %p110, %p111; + and.pred %p113, %p109, %p112; + xor.pred %p114, %p113, %p9; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r321, %r318, %r320; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r322, 0, %r321, %p114; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r323, %r322, %r320; + xor.b32 %r324, %r322, %r318; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r325, %r319, %r317; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r326, 0, %r325, %p114; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r327, %r326, %r317; + xor.b32 %r328, %r326, %r319; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r329, %r323, %r28; + mul.lo.s32 %r330, %r324, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r331, %r329, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r332, %r329, %r331; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r333, %r330, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r334, %r330, %r333; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r335, %r323, %r24; + mul.lo.s32 %r336, %r324, %r24; + .loc 3 291 36 // standard.py:291:36 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r337, %r335, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r338, %r335, %r337; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r339, %r336, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r340, %r336, %r339; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r341, %r327, %r28; + mul.lo.s32 %r342, %r328, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r343, %r341, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r344, %r341, %r343; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r345, %r342, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r346, %r342, %r345; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r347, %r327, %r24; + mul.lo.s32 %r348, %r328, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r349, %r347, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r350, %r347, %r349; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r351, %r348, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r352, %r348, %r351; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p115, %r334, %r340; + setp.ge.s32 %p116, %r332, %r338; + setp.ne.b32 %p117, %r332, %r338; + setp.ne.b32 %p118, %r334, %r340; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r353, %r332, %r338; + xor.b32 %r354, %r334, %r340; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.le.s32 %p119, %r344, %r350; + setp.le.s32 %p120, %r346, %r352; + or.pred %p121, %p118, %p120; + or.pred %p122, %p117, %p119; + and.pred %p123, %p116, %p122; + and.pred %p124, %p115, %p121; + xor.pred %p125, %p124, %p28; + xor.pred %p126, %p123, %p28; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r355, 0, %r353, %p126; + selp.b32 %r356, 0, %r354, %p125; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r357, %r355, %r323; + xor.b32 %r358, %r356, %r324; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r359, %r346, %r352; + xor.b32 %r360, %r344, %r350; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r361, 0, %r360, %p126; + selp.b32 %r362, 0, %r359, %p125; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r363, %r362, %r328; + xor.b32 %r364, %r361, %r327; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r365, %r357, %r27; + mul.lo.s32 %r366, %r358, %r27; + .loc 3 291 36 // standard.py:291:36 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r367, %r365, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r368, %r365, %r367; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r369, %r366, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r370, %r366, %r369; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r371, %r357, %r22; + mul.lo.s32 %r372, %r358, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r373, %r371, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r374, %r371, %r373; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r375, %r372, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r376, %r372, %r375; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r377, %r364, %r27; + mul.lo.s32 %r378, %r363, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r379, %r377, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r380, %r377, %r379; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r381, %r378, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r382, %r378, %r381; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r383, %r364, %r22; + mul.lo.s32 %r384, %r363, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r385, %r383, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r386, %r383, %r385; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r387, %r384, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r388, %r384, %r387; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p127, %r368, %r374; + setp.ge.s32 %p128, %r370, %r376; + setp.ne.b32 %p129, %r370, %r376; + setp.ne.b32 %p130, %r368, %r374; + setp.le.s32 %p131, %r382, %r388; + setp.le.s32 %p132, %r380, %r386; + or.pred %p133, %p130, %p132; + or.pred %p134, %p129, %p131; + and.pred %p135, %p128, %p134; + and.pred %p136, %p127, %p133; + xor.pred %p137, %p136, %p28; + xor.pred %p138, %p135, %p28; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r389, %r368, %r374; + xor.b32 %r390, %r370, %r376; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r391, 0, %r390, %p138; + selp.b32 %r392, 0, %r389, %p137; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r393, %r392, %r357; + xor.b32 %r394, %r391, %r358; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + 
xor.b32 %r395, %r380, %r386; + xor.b32 %r396, %r382, %r388; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r397, 0, %r396, %p138; + selp.b32 %r398, 0, %r395, %p137; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r399, %r398, %r364; + xor.b32 %r400, %r397, %r363; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p139, %r393, %r394; + setp.ne.b32 %p140, %r393, %r394; + setp.le.s32 %p141, %r399, %r400; + or.pred %p142, %p140, %p141; + and.pred %p143, %p139, %p142; + xor.pred %p144, %p143, %p28; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r401, %r393, %r394; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r402, 0, %r401, %p144; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r403, %r402, %r393; + xor.b32 %r404, %r402, %r394; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r405, %r399, %r400; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r406, 0, %r405, %p144; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r407, %r406, %r399; + xor.b32 %r408, %r406, %r400; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r409, %r403, %r29; + mul.lo.s32 %r410, %r404, %r29; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r411, %r409, 4, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r412, %r409, %r411; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r413, %r410, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r414, %r410, %r413; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r415, %r403, %r26; + mul.lo.s32 %r416, %r404, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r417, %r415, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r418, %r415, %r417; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r419, %r416, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r420, %r416, %r419; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r421, %r407, %r29; + mul.lo.s32 %r422, %r408, %r29; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r423, %r421, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r424, %r421, %r423; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r425, %r422, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r426, %r422, %r425; + .loc 2 551 23 // triton_helpers.py:551:23 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r427, %r407, %r26; + mul.lo.s32 %r428, %r408, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r429, %r427, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r430, %r427, %r429; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r431, %r428, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r432, %r428, %r431; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p145, %r412, %r418; + setp.lt.s32 %p146, %r414, %r420; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p147, %r414, %r420; + setp.eq.b32 %p148, %r412, %r418; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.gt.s32 %p149, %r426, %r432; + setp.gt.s32 %p150, %r424, %r430; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p151, %p148, %p150; + and.pred %p152, %p147, %p149; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p153, %p146, %p152; + or.pred %p154, %p145, %p151; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r433, %r414, %r420; + xor.b32 %r434, %r412, %r418; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r435, %r434, 0, %p154; + selp.b32 %r436, %r433, 0, %p153; + .loc 2 600 15 // triton_helpers.py:600:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r437, %r436, %r404; + xor.b32 %r438, %r435, %r403; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r439, %r424, %r430; + xor.b32 %r440, %r426, %r432; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r441, %r439, 0, %p154; + selp.b32 %r442, %r440, 0, %p153; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r443, %r441, %r407; + xor.b32 %r444, %r442, %r408; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r445, %r438, %r28; + mul.lo.s32 %r446, %r437, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r447, %r445, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r448, %r445, %r447; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r449, %r446, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r450, %r446, %r449; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r451, %r438, %r24; + mul.lo.s32 %r452, %r437, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r453, %r451, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r454, %r451, %r453; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 
%r455, %r452, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r456, %r452, %r455; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r457, %r443, %r28; + mul.lo.s32 %r458, %r444, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r459, %r457, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r460, %r457, %r459; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r461, %r458, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r462, %r458, %r461; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r463, %r443, %r24; + mul.lo.s32 %r464, %r444, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r465, %r463, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r466, %r463, %r465; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r467, %r464, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r468, %r464, %r467; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p155, %r448, %r454; + setp.lt.s32 %p156, %r450, %r456; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p157, %r450, %r456; + 
setp.eq.b32 %p158, %r448, %r454; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.gt.s32 %p159, %r462, %r468; + setp.gt.s32 %p160, %r460, %r466; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p161, %p158, %p160; + and.pred %p162, %p157, %p159; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p163, %p156, %p162; + or.pred %p164, %p155, %p161; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r469, %r450, %r456; + xor.b32 %r470, %r448, %r454; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r471, %r470, 0, %p164; + selp.b32 %r472, %r469, 0, %p163; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r473, %r472, %r437; + xor.b32 %r474, %r471, %r438; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r475, %r460, %r466; + xor.b32 %r476, %r462, %r468; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r477, %r475, 0, %p164; + selp.b32 %r478, %r476, 0, %p163; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r479, %r477, %r443; + xor.b32 %r480, %r478, %r444; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r481, %r474, %r27; + mul.lo.s32 %r482, %r473, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r483, %r481, 1, 31, -1; + shfl.sync.bfly.b32 %r485, %r482, 1, 31, -1; + .loc 2 
539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r487, %r474, %r22; + mul.lo.s32 %r488, %r473, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r489, %r487, 1, 31, -1; + shfl.sync.bfly.b32 %r491, %r488, 1, 31, -1; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r493, %r479, %r27; + mul.lo.s32 %r494, %r480, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r495, %r493, 1, 31, -1; + shfl.sync.bfly.b32 %r497, %r494, 1, 31, -1; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r499, %r479, %r22; + mul.lo.s32 %r500, %r480, %r22; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r501, %r499, 1, 31, -1; + shfl.sync.bfly.b32 %r503, %r500, 1, 31, -1; +$L__tmp4: + .loc 1 54 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:54:35 + and.pred %p178, %p1, %p4; + selp.b64 %rd17, 1, 0, %p178; + and.pred %p179, %p1, %p5; + selp.b64 %rd18, 1, 0, %p179; +$L__tmp5: + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd19, %rd17, %rd18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + cvt.u32.u64 %r521, %rd19; + shfl.sync.bfly.b32 %r522, %r521, 4, 31, -1; + mov.b32 %r523, 0; + shfl.sync.bfly.b32 %r524, %r523, 4, 31, -1; + cvt.u64.u32 %rd20, %r522; + cvt.u64.u32 %rd21, %r524; + shl.b64 %rd22, %rd21, 32; + or.b64 %rd23, %rd20, %rd22; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd24, %rd19, %rd23; + .loc 3 291 36 // 
standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r525}, %rd24; + cvt.u32.u64 %r526, %rd24; + shfl.sync.bfly.b32 %r527, %r526, 2, 31, -1; + shfl.sync.bfly.b32 %r528, %r525, 2, 31, -1; + cvt.u64.u32 %rd25, %r527; + cvt.u64.u32 %rd26, %r528; + shl.b64 %rd27, %rd26, 32; + or.b64 %rd28, %rd25, %rd27; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd29, %rd24, %rd28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r529}, %rd29; + cvt.u32.u64 %r530, %rd29; + shfl.sync.bfly.b32 %r531, %r530, 1, 31, -1; + shfl.sync.bfly.b32 %r532, %r529, 1, 31, -1; + cvt.u64.u32 %rd30, %r531; + cvt.u64.u32 %rd31, %r532; + shl.b64 %rd32, %rd31, 32; + or.b64 %rd33, %rd30, %rd32; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd34, %rd29, %rd33; +$L__tmp6: + .loc 1 60 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:60:21 + mov.b32 %r533, global_smem; + add.s32 %r534, %r533, %r3; + st.shared.b64 [%r534], %rd34; + bar.sync 0; + shl.b32 %r535, %r4, 3; + add.s32 %r536, %r533, %r535; + ld.shared.b64 %rd2, [%r536]; + .loc 1 58 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:58:35 + and.pred %p180, %p1, %p93; + selp.b64 %rd35, 1, 0, %p180; + and.pred %p181, %p1, %p94; + selp.b64 %rd36, 1, 0, %p181; +$L__tmp7: + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd37, %rd35, %rd36; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + cvt.u32.u64 %r537, %rd37; + shfl.sync.bfly.b32 %r538, %r537, 4, 31, -1; + shfl.sync.bfly.b32 %r539, %r523, 4, 31, -1; + cvt.u64.u32 %rd38, %r538; + cvt.u64.u32 %rd39, %r539; + shl.b64 %rd40, %rd39, 32; + or.b64 %rd41, %rd38, %rd40; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd42, %rd37, %rd41; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r540}, %rd42; + cvt.u32.u64 %r541, %rd42; + shfl.sync.bfly.b32 %r542, %r541, 2, 31, -1; + shfl.sync.bfly.b32 %r543, %r540, 2, 31, -1; + cvt.u64.u32 %rd43, %r542; + cvt.u64.u32 %rd44, %r543; + shl.b64 %rd45, %rd44, 32; + or.b64 %rd46, %rd43, %rd45; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd47, %rd42, %rd46; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r544}, %rd47; + cvt.u32.u64 %r545, %rd47; + shfl.sync.bfly.b32 %r546, %r545, 1, 31, -1; + shfl.sync.bfly.b32 %r547, %r544, 1, 31, -1; + cvt.u64.u32 %rd48, %r546; + cvt.u64.u32 %rd49, %r547; + shl.b64 %rd50, %rd49, 32; + or.b64 %rd51, %rd48, %rd50; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd3, %rd47, %rd51; +$L__tmp8: + .loc 1 61 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:61:21 + bar.sync 0; + st.shared.b64 [%r534], %rd3; + bar.sync 0; + .loc 1 60 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:60:21 + cvt.u32.u64 %r548, %rd34; + .loc 1 64 19 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:64:19 + setp.lt.s32 %p182, %r6, %r548; + setp.lt.s32 %p183, %r7, %r548; + .loc 1 66 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:66:35 + selp.b32 %r549, %r8, 16, %p182; + selp.b32 %r550, %r9, 16, %p183; + .loc 1 68 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:68:20 + add.s32 %r551, %r549, 17; + add.s32 %r552, %r550, 17; + .loc 1 69 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:69:20 + setp.lt.s32 %p184, %r549, 0; + setp.lt.s32 %p185, %r550, 0; + .loc 1 70 35 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:70:35 + selp.b32 %r12, %r551, %r549, %p184; + selp.b32 %r13, %r552, %r550, %p185; + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + max.u32 %r553, %r12, %r13; + setp.lt.u32 %p186, %r553, 17; + or.pred %p187, %p2, %p186; + @%p187 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + add.s32 %r484, %r481, %r483; + add.s32 %r486, %r482, %r485; + add.s32 %r490, %r487, %r489; + add.s32 %r492, %r488, %r491; + add.s32 %r496, %r493, %r495; + add.s32 %r498, %r494, %r497; + add.s32 %r502, %r499, %r501; + add.s32 %r504, %r500, %r503; + setp.lt.s32 %p165, %r486, %r492; + setp.lt.s32 %p166, %r484, %r490; + setp.eq.b32 %p167, %r484, %r490; + setp.eq.b32 %p168, %r486, %r492; + setp.gt.s32 %p169, %r496, %r502; + setp.gt.s32 %p170, %r498, %r504; + and.pred %p171, %p168, %p170; + and.pred %p172, %p167, %p169; + or.pred %p173, %p166, %p172; + or.pred %p174, %p165, %p171; + xor.b32 %r505, %r484, %r490; + xor.b32 %r506, %r486, %r492; + selp.b32 %r507, %r506, 0, %p174; + selp.b32 %r508, %r505, 0, %p173; + xor.b32 %r509, %r508, %r474; + xor.b32 %r510, %r507, %r473; + xor.b32 %r511, %r496, %r502; + xor.b32 %r512, %r498, %r504; + selp.b32 %r513, %r511, 0, %p173; + selp.b32 %r514, %r512, 0, %p174; + xor.b32 %r515, %r513, %r479; + xor.b32 %r516, %r514, %r480; + setp.lt.s32 %p175, %r509, %r510; + setp.eq.b32 %p176, %r509, %r510; + setp.gt.s32 %p177, %r515, %r516; + xor.b32 %r517, %r515, %r516; + selp.b32 %r518, %r517, 0, %p177; + selp.b32 %r519, %r518, 0, %p176; + selp.b32 %r520, %r517, %r519, %p175; + xor.b32 %r10, %r520, %r515; + xor.b32 %r11, %r520, %r516; + .loc 1 61 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:61:21 + ld.shared.b64 %rd4, [%r536]; + cvt.u32.u64 %r554, %rd3; + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + bar.sync 0; + .loc 1 75 19 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:75:19 + setp.lt.s32 %p189, %r6, %r554; + setp.lt.s32 %p190, %r7, %r554; + .loc 1 76 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:76:35 + selp.b32 %r555, %r10, 16, %p189; + selp.b32 %r556, %r11, 16, %p190; + .loc 1 77 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:77:20 + add.s32 %r557, %r555, 17; + add.s32 %r558, %r556, 17; + .loc 1 78 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:78:20 + setp.lt.s32 %p191, %r555, 0; + setp.lt.s32 %p192, %r556, 0; + .loc 1 79 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:79:35 + selp.b32 %r14, %r557, %r555, %p191; + selp.b32 %r15, %r558, %r556, %p192; + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + max.u32 %r559, %r14, %r15; + setp.lt.u32 %p193, %r559, 17; + or.pred %p194, %p2, %p193; + @%p194 bra $L__BB0_4; + bra.uni $L__BB0_3; +$L__BB0_4: + .loc 1 0 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0:63 + ld.param.b64 %rd10, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6]; + ld.param.b64 %rd9, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5]; + ld.param.b64 %rd8, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4]; + ld.param.b64 %rd7, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3]; + ld.param.b64 %rd6, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2]; + ld.param.b64 %rd5, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1]; + cvt.s64.s32 %rd1, %r19; + .loc 1 61 21 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:61:21 + cvt.u32.u64 %r561, %rd4; + .loc 1 60 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:60:21 + cvt.u32.u64 %r560, %rd2; + .loc 1 25 23 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:25:23 + or.b32 %r570, %r1, %r4; + .loc 1 26 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:26:21 + setp.lt.s32 %p198, %r570, 32; + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + bar.sync 0; + .loc 1 81 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:81:25 + mul.wide.s32 %rd60, %r570, 4; + add.s64 %rd52, %rd5, %rd60; + .loc 1 81 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:81:37 + setp.eq.b32 %p203, %r3, 0; + and.pred %p195, %p203, %p198; + // begin inline asm + @%p195 st.global.b32 [ %rd52 + 0 ], { %r560 }; + // end inline asm + .loc 1 82 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:82:25 + add.s64 %rd53, %rd6, %rd60; + .loc 1 82 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:82:37 + // begin inline asm + @%p195 st.global.b32 [ %rd53 + 0 ], { %r561 }; + // end inline asm + .loc 1 83 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:83:25 + shl.b64 %rd61, %rd1, 2; + add.s64 %rd54, %rd7, %rd61; + .loc 1 83 47 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:83:47 + // begin inline asm + @%p1 st.global.v2.b32 [ %rd54 + 0 ], { %r8, %r9 }; + // end inline asm + .loc 1 84 52 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:52 + mul.lo.s32 %r571, %r5, 17; + .loc 1 84 49 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:49 + add.s32 %r572, %r12, %r571; + add.s32 %r573, %r13, %r571; + .loc 1 84 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:25 + mad.wide.s32 %rd62, %r572, 4, %rd8; + mad.wide.s32 %rd63, %r573, 4, %rd8; + .loc 1 84 85 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:85 + bar.sync 0; + shl.b32 %r574, %r2, 5; + 
and.b32 %r575, %r574, 96; + and.b32 %r576, %r2, 48; + shl.b32 %r577, %r576, 3; + shl.b32 %r578, %r2, 1; + and.b32 %r579, %r578, 120; + or.b32 %r580, %r575, %r577; + xor.b32 %r581, %r580, %r579; + add.s32 %r583, %r533, %r581; + st.shared.b64 [%r583], %rd62; + xor.b32 %r584, %r581, 8; + add.s32 %r585, %r533, %r584; + st.shared.b64 [%r585+512], %rd63; + bar.sync 0; + shl.b32 %r586, %r2, 6; + and.b32 %r587, %r586, 384; + shl.b32 %r588, %r4, 4; + shl.b32 %r589, %r576, 1; + bfe.s32 %r590, %r2, 3, 1; + and.b32 %r591, %r590, 520; + or.b32 %r592, %r587, %r588; + xor.b32 %r593, %r592, %r589; + or.b32 %r594, %r593, %r591; + add.s32 %r595, %r533, %r594; + ld.shared.b64 %rd55, [%r595]; + xor.b32 %r596, %r594, 8; + add.s32 %r597, %r533, %r596; + ld.shared.b64 %rd56, [%r597]; + mov.b32 %r564, 1; + // begin inline asm + @%p198 st.global.b32 [ %rd55 + 0 ], { %r564 }; + // end inline asm + // begin inline asm + @%p198 st.global.b32 [ %rd56 + 0 ], { %r564 }; + // end inline asm + .loc 1 85 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:85:25 + add.s64 %rd57, %rd9, %rd61; + .loc 1 85 47 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:85:47 + // begin inline asm + @%p1 st.global.v2.b32 [ %rd57 + 0 ], { %r10, %r11 }; + // end inline asm + .loc 1 86 49 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:49 + add.s32 %r598, %r14, %r571; + add.s32 %r599, %r15, %r571; + .loc 1 86 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:25 + mad.wide.s32 %rd64, %r598, 4, %rd10; + mad.wide.s32 %rd65, %r599, 4, %rd10; + .loc 1 86 85 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:85 + bar.sync 0; + st.shared.b64 [%r583], %rd64; + st.shared.b64 [%r585+512], %rd65; + bar.sync 0; + ld.shared.b64 %rd58, [%r595]; + ld.shared.b64 %rd59, [%r597]; + // begin inline asm + @%p198 st.global.b32 [ %rd58 + 0 ], { %r564 }; + // end inline asm + // begin inline asm + @%p198 st.global.b32 [ %rd59 + 0 ], { %r564 }; + // end inline asm + .loc 1 86 4 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:4 + ret; +$L__BB0_1: + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + { // callseq 1, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd72, assertFunc_0; + cvta.global.u64 %rd73, %rd72; + st.param.b64 [param3], %rd73; + mov.b64 %rd74, assertFile_0; + cvta.global.u64 %rd75, %rd74; + st.param.b64 [param1], %rd75; + mov.b64 %rd76, assertMessage_0; + cvta.global.u64 %rd77, %rd76; + st.param.b64 [param0], %rd77; + st.param.b64 [param4], 1; + st.param.b32 [param2], 71; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 1 + trap; +$L__BB0_3: + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + { // callseq 0, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd66, assertFunc_1; + cvta.global.u64 %rd67, %rd66; + st.param.b64 [param3], %rd67; + mov.b64 %rd68, assertFile_1; + cvta.global.u64 %rd69, %rd68; + st.param.b64 [param1], %rd69; + mov.b64 %rd70, assertMessage_1; + cvta.global.u64 %rd71, %rd70; + st.param.b64 [param0], %rd71; + st.param.b64 [param4], 1; + st.param.b32 [param2], 80; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 0 + trap; +$L__tmp9: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .file 3 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 
// DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 376 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. 
Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x171 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 98 +.b8 115 +.b8 54 +.b8 53 +.b8 50 +.b8 105 +.b8 55 +.b8 99 +.b8 116 +.b8 53 +.b8 55 +.b8 117 +.b8 103 +.b8 120 +.b8 54 +.b8 118 +.b8 110 +.b8 99 +.b8 51 +.b8 53 +.b8 110 +.b8 54 +.b8 51 +.b8 116 +.b8 105 +.b8 110 +.b8 112 +.b8 122 +.b8 117 +.b8 54 +.b8 98 +.b8 97 +.b8 51 +.b8 122 +.b8 121 +.b8 109 +.b8 117 +.b8 102 +.b8 111 +.b8 100 +.b8 105 +.b8 103 +.b8 52 +.b8 104 +.b8 112 +.b8 122 +.b8 97 +.b8 114 +.b8 119 +.b8 99 +.b8 108 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 98 +.b8 115 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x7a DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 112 +.b8 101 +.b8 114 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 95 +.b8 116 +.b8 111 +.b8 95 +.b8 99 +.b8 111 +.b8 112 +.b8 121 +.b8 95 +.b8 97 +.b8 114 +.b8 97 +.b8 110 +.b8 103 +.b8 101 +.b8 95 +.b8 98 +.b8 105 +.b8 116 +.b8 119 +.b8 105 +.b8 115 +.b8 101 +.b8 95 +.b8 97 +.b8 110 +.b8 100 +.b8 95 +.b8 101 +.b8 113 +.b8 95 +.b8 103 +.b8 116 +.b8 95 +.b8 105 +.b8 110 +.b8 100 +.b8 101 +.b8 120 +.b8 95 +.b8 112 +.b8 117 +.b8 116 +.b8 95 +.b8 108 +.b8 116 +.b8 95 +.b8 110 +.b8 101 +.b8 119 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 115 +.b8 99 +.b8 97 +.b8 108 +.b8 
97 +.b8 114 +.b8 95 +.b8 116 +.b8 101 +.b8 110 +.b8 115 +.b8 111 +.b8 114 +.b8 95 +.b8 115 +.b8 111 +.b8 114 +.b8 116 +.b8 95 +.b8 115 +.b8 117 +.b8 109 +.b8 95 +.b8 117 +.b8 110 +.b8 115 +.b8 113 +.b8 117 +.b8 101 +.b8 101 +.b8 122 +.b8 101 +.b8 95 +.b8 118 +.b8 105 +.b8 101 +.b8 119 +.b8 95 +.b8 119 +.b8 104 +.b8 101 +.b8 114 +.b8 101 +.b8 95 +.b8 50 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0x105:0x76 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0x11a:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 46 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x132:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp3 // DW_AT_low_pc +.b64 $L__tmp4 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 51 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x14a:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp5 // DW_AT_low_pc +.b64 $L__tmp6 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 55 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x162:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp7 // DW_AT_low_pc +.b64 $L__tmp8 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 59 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source 
b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source new file mode 100644 index 0000000000000000000000000000000000000000..36d4f099ba5fd8b198d8fdbf3588357040d1206c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source @@ -0,0 +1,1405 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc91 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":640:0) +#loc95 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":607:0) +#loc103 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":518:0) +#loc141 = loc(unknown) +#loc166 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc170 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc175 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":86:0) +#loc179 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":63:0) +#loc188 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":131:0) +#loc193 = loc("in_ptr0"(#loc)) +#loc194 = loc("out_ptr4"(#loc)) +#loc195 = loc("out_ptr5"(#loc)) +#loc196 = loc("out_ptr6"(#loc)) +#loc197 = loc("out_ptr7"(#loc)) +#loc198 = loc("out_ptr8"(#loc)) +#loc199 = loc("out_ptr9"(#loc)) +#loc200 = loc("xnumel"(#loc)) +#loc201 = loc("r0_numel"(#loc)) +#loc257 = loc("x"(#loc91)) +#loc258 = loc("idxs"(#loc91)) 
+#loc259 = loc("x"(#loc95)) +#loc260 = loc("idxs"(#loc95)) +#loc265 = loc("x"(#loc103)) +#loc266 = loc("idxs"(#loc103)) +#loc267 = loc("flip"(#loc103)) +#loc323 = loc("input"(#loc166)) +#loc324 = loc("a"(#loc170)) +#loc325 = loc("b"(#loc170)) +#loc327 = loc("x"(#loc175)) +#loc328 = loc("x"(#loc179)) +#loc329 = loc("input"(#loc188)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 32 : i32 loc(#loc202) + %r0_numel_1 = arith.constant 16 : i32 loc(#loc203) + %xoffset = tt.get_program_id x : i32 loc(#loc204) + %xoffset_2 = arith.constant 8 : i32 loc(#loc205) + %xoffset_3 = arith.constant 8 : i32 loc(#loc205) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc205) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc206) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc207) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<8x1xi32> loc(#loc208) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<8x1xi32> loc(#loc208) + %xmask = arith.constant dense<32> : tensor<8x1xi32> loc(#loc209) + %xmask_8 = arith.cmpi slt, %xindex_7, %xmask : tensor<8x1xi32> loc(#loc209) + %r0_index = tt.make_range {end = 16 : i32, start = 
0 : i32} : tensor<16xi32> loc(#loc210) + %r0_index_9 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc211) + %r0_offset = arith.constant 0 : i32 loc(#loc212) + %r0_mask = arith.constant true loc(#loc213) + %r0_mask_10 = arith.constant dense : tensor<8x16xi1> loc(#loc213) + %tmp0 = arith.constant 16 : i32 loc(#loc214) + %tmp0_11 = arith.constant 16 : i32 loc(#loc214) + %tmp0_12 = arith.constant dense<16> : tensor<8x1xi32> loc(#loc214) + %tmp0_13 = arith.muli %tmp0_12, %xindex_7 : tensor<8x1xi32> loc(#loc214) + %tmp0_14 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc215) + %tmp0_15 = tt.broadcast %tmp0_13 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc215) + %tmp0_16 = arith.addi %tmp0_14, %tmp0_15 : tensor<8x16xi32> loc(#loc215) + %tmp0_17 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc216) + %tmp0_18 = tt.addptr %tmp0_17, %tmp0_16 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc216) + %tmp0_19 = arith.constant 0.000000e+00 : f32 loc(#loc217) + %tmp0_20 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc217) + %tmp0_21 = arith.constant dense<0.000000e+00> : tensor<8x16xf32> loc(#loc217) + %tmp0_22 = arith.fptosi %tmp0_21 : tensor<8x16xf32> to tensor<8x16xi64> loc(#loc217) + %tmp0_23 = tt.load %tmp0_18, %tmp0_20, %tmp0_22 : tensor<8x16x!tt.ptr> loc(#loc217) + %tmp1 = arith.constant 0 : i64 loc(#loc218) + %tmp1_24 = arith.constant dense<0> : tensor<1x1xi64> loc(#loc218) + %tmp2 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc219) + %tmp2_25 = arith.cmpi sgt, %tmp0_23, %tmp2 : tensor<8x16xi64> loc(#loc219) + %tmp3 = arith.constant 16384 : i64 loc(#loc220) + %tmp3_26 = arith.constant dense<16384> : tensor<1x1xi64> loc(#loc220) + %tmp4 = arith.constant dense<16384> : tensor<8x16xi64> loc(#loc221) + %tmp4_27 = arith.cmpi slt, %tmp0_23, %tmp4 : tensor<8x16xi64> loc(#loc221) + %tmp5 = arith.andi %tmp2_25, %tmp4_27 : tensor<8x16xi1> loc(#loc222) + %tmp6 = arith.extui %tmp5 
: tensor<8x16xi1> to tensor<8x16xi8> loc(#loc223) + %tmp7 = arith.extsi %tmp6 : tensor<8x16xi8> to tensor<8x16xi32> loc(#loc224) + %tmp9 = arith.trunci %r0_index_9 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc225) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16> -> tensor<8x16xi16> loc(#loc226) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S8_16S_i16S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp7, %tmp11) : (tensor<8x16xi32>, tensor<8x16xi16>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc26) + %tmp14 = arith.constant dense<16384> : tensor<8x16xi64> loc(#loc227) + %tmp14_28 = arith.cmpi eq, %tmp0_23, %tmp14 : tensor<8x16xi64> loc(#loc227) + %tmp15 = arith.extui %tmp14_28 : tensor<8x16xi1> to tensor<8x16xi8> loc(#loc228) + %tmp16 = arith.extsi %tmp15 : tensor<8x16xi8> to tensor<8x16xi32> loc(#loc229) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S8_16S_i16S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp16, %tmp11) : (tensor<8x16xi32>, tensor<8x16xi16>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc30) + %tmp20 = arith.extsi %tmp7 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc230) + %tmp23 = arith.constant 0 : i32 loc(#loc231) + %tmp23_29 = arith.constant 0 : i64 loc(#loc231) + %tmp23_30 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc231) + %tmp23_31 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc231) + %tmp23_32 = arith.select %tmp23_31, %tmp20, %tmp23_30 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc231) + %tmp24 = tt.call @"triton.language.standard.sum__i64S8_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp23_32) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc232) + %tmp24_33 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc233) + %tmp25 = arith.extsi %tmp16 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc234) + %tmp28 = 
arith.constant 0 : i32 loc(#loc235) + %tmp28_34 = arith.constant 0 : i64 loc(#loc235) + %tmp28_35 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc235) + %tmp28_36 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc235) + %tmp28_37 = arith.select %tmp28_36, %tmp25, %tmp28_35 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc235) + %tmp29 = tt.call @"triton.language.standard.sum__i64S8_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp28_37) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc236) + %tmp29_38 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc237) + %tmp30 = arith.trunci %tmp24_33 : tensor<8x1xi64> to tensor<8x1xi32> loc(#loc238) + %tmp31 = arith.trunci %tmp29_38 : tensor<8x1xi64> to tensor<8x1xi32> loc(#loc239) + %tmp32 = arith.extsi %0#1 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc240) + %tmp33 = arith.trunci %tmp32 : tensor<8x16xi64> to tensor<8x16xi32> loc(#loc241) + %tmp34 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc242) + %tmp34_39 = tt.broadcast %tmp30 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc242) + %tmp34_40 = arith.cmpi slt, %tmp34, %tmp34_39 : tensor<8x16xi32> loc(#loc242) + %tmp35 = arith.constant 16 : i32 loc(#loc243) + %tmp35_41 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc243) + %tmp36 = arith.constant dense<16> : tensor<8x16xi32> loc(#loc244) + %tmp36_42 = arith.select %tmp34_40, %tmp33, %tmp36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc244) + %tmp37 = arith.constant 17 : i32 loc(#loc245) + %tmp37_43 = arith.constant dense<17> : tensor<8x16xi32> loc(#loc245) + %tmp38 = arith.addi %tmp36_42, %tmp37_43 : tensor<8x16xi32> loc(#loc246) + %tmp39 = arith.constant 0 : i32 loc(#loc247) + %tmp39_44 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc247) + %tmp39_45 = arith.cmpi slt, %tmp36_42, %tmp39_44 : tensor<8x16xi32> loc(#loc247) + %tmp40 = arith.select %tmp39_45, %tmp38, %tmp36_42 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc248) + %c0_i32 
= arith.constant 0 : i32 loc(#loc50) + %cst = arith.constant dense<0> : tensor<8x16xi32> loc(#loc50) + %2 = arith.cmpi sle, %cst, %tmp40 : tensor<8x16xi32> loc(#loc50) + %c17_i32 = arith.constant 17 : i32 loc(#loc51) + %cst_46 = arith.constant dense<17> : tensor<8x16xi32> loc(#loc51) + %3 = arith.cmpi slt, %tmp40, %cst_46 : tensor<8x16xi32> loc(#loc51) + %4 = arith.andi %2, %3 : tensor<8x16xi1> loc(#loc52) + %true = arith.constant true loc(#loc53) + %cst_47 = arith.constant dense : tensor<8x1xi1> loc(#loc53) + %5 = arith.xori %xmask_8, %cst_47 : tensor<8x1xi1> loc(#loc53) + %6 = tt.broadcast %5 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc54) + %7 = arith.ori %4, %6 : tensor<8x16xi1> loc(#loc54) + tt.assert %7, "index out of bounds: 0 <= tmp40 < 17" : tensor<8x16xi1> loc(#loc55) + %tmp42 = arith.constant 1 : i32 loc(#loc249) + %tmp42_48 = arith.constant dense<1> : tensor<1x1xi32> loc(#loc249) + %tmp43 = arith.extsi %1#1 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc250) + %tmp44 = arith.trunci %tmp43 : tensor<8x16xi64> to tensor<8x16xi32> loc(#loc251) + %tmp45 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc252) + %tmp45_49 = tt.broadcast %tmp31 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc252) + %tmp45_50 = arith.cmpi slt, %tmp45, %tmp45_49 : tensor<8x16xi32> loc(#loc252) + %tmp46 = arith.constant dense<16> : tensor<8x16xi32> loc(#loc253) + %tmp46_51 = arith.select %tmp45_50, %tmp44, %tmp46 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc253) + %tmp47 = arith.addi %tmp46_51, %tmp37_43 : tensor<8x16xi32> loc(#loc254) + %tmp48 = arith.constant 0 : i32 loc(#loc255) + %tmp48_52 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc255) + %tmp48_53 = arith.cmpi slt, %tmp46_51, %tmp48_52 : tensor<8x16xi32> loc(#loc255) + %tmp49 = arith.select %tmp48_53, %tmp47, %tmp46_51 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc256) + %c0_i32_54 = arith.constant 0 : i32 loc(#loc64) + %cst_55 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc64) + %8 = 
arith.cmpi sle, %cst_55, %tmp49 : tensor<8x16xi32> loc(#loc64) + %c17_i32_56 = arith.constant 17 : i32 loc(#loc65) + %cst_57 = arith.constant dense<17> : tensor<8x16xi32> loc(#loc65) + %9 = arith.cmpi slt, %tmp49, %cst_57 : tensor<8x16xi32> loc(#loc65) + %10 = arith.andi %8, %9 : tensor<8x16xi1> loc(#loc66) + %true_58 = arith.constant true loc(#loc67) + %cst_59 = arith.constant dense : tensor<8x1xi1> loc(#loc67) + %11 = arith.xori %xmask_8, %cst_59 : tensor<8x1xi1> loc(#loc67) + %12 = tt.broadcast %11 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc68) + %13 = arith.ori %10, %12 : tensor<8x16xi1> loc(#loc68) + tt.assert %13, "index out of bounds: 0 <= tmp49 < 17" : tensor<8x16xi1> loc(#loc69) + %14 = tt.splat %out_ptr4 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc70) + %15 = tt.addptr %14, %xindex_7 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc70) + tt.store %15, %tmp30, %xmask_8 : tensor<8x1x!tt.ptr> loc(#loc71) + %16 = tt.splat %out_ptr5 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc72) + %17 = tt.addptr %16, %xindex_7 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc72) + tt.store %17, %tmp31, %xmask_8 : tensor<8x1x!tt.ptr> loc(#loc73) + %c16_i32 = arith.constant 16 : i32 loc(#loc74) + %c16_i32_60 = arith.constant 16 : i32 loc(#loc74) + %cst_61 = arith.constant dense<16> : tensor<8x1xi32> loc(#loc74) + %18 = arith.muli %cst_61, %xindex_7 : tensor<8x1xi32> loc(#loc74) + %19 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc75) + %20 = tt.broadcast %18 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc75) + %21 = arith.addi %19, %20 : tensor<8x16xi32> loc(#loc75) + %22 = tt.splat %out_ptr6 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc76) + %23 = tt.addptr %22, %21 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc76) + %24 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc77) + tt.store %23, %tmp33, %24 : tensor<8x16x!tt.ptr> loc(#loc77) + %c17_i32_62 = arith.constant 17 : i32 loc(#loc78) + %c17_i32_63 = arith.constant 17 : i32 loc(#loc78) + 
%cst_64 = arith.constant dense<17> : tensor<8x1xi32> loc(#loc78) + %25 = arith.muli %cst_64, %xindex_7 : tensor<8x1xi32> loc(#loc78) + %26 = tt.broadcast %25 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc79) + %27 = arith.addi %tmp40, %26 : tensor<8x16xi32> loc(#loc79) + %28 = tt.splat %out_ptr7 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc80) + %29 = tt.addptr %28, %27 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc80) + %cst_65 = arith.constant dense<1> : tensor<8x16xi32> loc(#loc81) + %30 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc81) + tt.store %29, %cst_65, %30 : tensor<8x16x!tt.ptr> loc(#loc81) + %c16_i32_66 = arith.constant 16 : i32 loc(#loc82) + %c16_i32_67 = arith.constant 16 : i32 loc(#loc82) + %cst_68 = arith.constant dense<16> : tensor<8x1xi32> loc(#loc82) + %31 = arith.muli %cst_68, %xindex_7 : tensor<8x1xi32> loc(#loc82) + %32 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc83) + %33 = tt.broadcast %31 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc83) + %34 = arith.addi %32, %33 : tensor<8x16xi32> loc(#loc83) + %35 = tt.splat %out_ptr8 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc84) + %36 = tt.addptr %35, %34 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc84) + %37 = tt.broadcast %xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc85) + tt.store %36, %tmp44, %37 : tensor<8x16x!tt.ptr> loc(#loc85) + %c17_i32_69 = arith.constant 17 : i32 loc(#loc86) + %c17_i32_70 = arith.constant 17 : i32 loc(#loc86) + %cst_71 = arith.constant dense<17> : tensor<8x1xi32> loc(#loc86) + %38 = arith.muli %cst_71, %xindex_7 : tensor<8x1xi32> loc(#loc86) + %39 = tt.broadcast %38 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc87) + %40 = arith.addi %tmp49, %39 : tensor<8x16xi32> loc(#loc87) + %41 = tt.splat %out_ptr9 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc88) + %42 = tt.addptr %41, %40 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc88) + %cst_72 = arith.constant dense<1> : tensor<8x16xi32> loc(#loc89) + %43 = tt.broadcast 
%xmask_8 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc89) + tt.store %42, %cst_72, %43 : tensor<8x16x!tt.ptr> loc(#loc89) + tt.return loc(#loc90) + } loc(#loc) + tt.func private @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S8_16S_i16S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc91)), %idxs: tensor<8x16xi16> loc("idxs"(#loc91))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i16S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs) : (tensor<8x16xi32>, tensor<8x16xi16>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc92) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1) : (tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc92) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1) : (tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc92) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1) : (tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc92) + tt.return %3#0, %3#1 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc93) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc94) + %5 = ub.poison : tensor<8x16xi32> 
loc(#loc94) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc94) + } loc(#loc91) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i16S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc95)), %idxs: tensor<8x16xi16> loc("idxs"(#loc95))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc264) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i16S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi16>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + tt.return %0#0, %0#1 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc101) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x16xi32> loc(#loc102) + %2 = ub.poison : tensor<8x16xi32> loc(#loc102) + tt.return %1, %2 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i16S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi16> loc("idxs"(#loc103)), %flip: tensor<8x16xi32> loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + 
%y = tt.reshape %x : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<64x2x1xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<64x2x1xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi16> -> 
tensor<64x2x1xi16> loc(#loc282) + %left_idx = arith.trunci %left_mask_4 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc283) + %left_idx_15 = tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<64x2x1xi16> loc(#loc284) + %left_idx_16 = arith.muli %y_idx, %left_idx_15 : tensor<64x2x1xi16> loc(#loc284) + %left_idx_17 = tt.call @"triton.language.standard.sum__i16S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_16) : (tensor<64x2x1xi16>) -> tensor<64x1xi32> loc(#loc285) + %left_idx_18 = tt.expand_dims %left_idx_17 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc286) + %left_idx_19 = tt.broadcast %left_idx_18 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc287) + %right_idx = arith.trunci %right_mask_1 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc288) + %right_idx_20 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<64x2x1xi16> loc(#loc289) + %right_idx_21 = arith.muli %y_idx, %right_idx_20 : tensor<64x2x1xi16> loc(#loc289) + %right_idx_22 = tt.call @"triton.language.standard.sum__i16S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_21) : (tensor<64x2x1xi16>) -> tensor<64x1xi32> loc(#loc290) + %right_idx_23 = tt.expand_dims %right_idx_22 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc291) + %right_idx_24 = tt.broadcast %right_idx_23 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_25 = tt.reshape %left_idx_19 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_26 = tt.reshape %right_idx_24 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_27 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_28 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : 
tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_49 = arith.constant true loc(#loc300) + %cond_50 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_51 = arith.xori %left_isnan, %cond_50 : tensor<8x16xi1> loc(#loc300) + %cond_52 = arith.andi %right_isnan, %cond_51 : tensor<8x16xi1> loc(#loc301) + %cond_53 = arith.ori %cond, %cond_52 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_53 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_49 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_50 = arith.ori %eq, %eq_49 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_50 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_29 = arith.cmpi sgt, %left_idx_25, %right_idx_26 : tensor<8x16xi32> loc(#loc306) + %cond_30 = arith.andi %3, %cond_29 : tensor<8x16xi1> loc(#loc307) + %cond_31 = arith.ori %1, %cond_30 : tensor<8x16xi1> loc(#loc308) + %cond_32 = arith.cmpi ugt, %right_valid_mask_28, %left_valid_mask_27 : tensor<8x16xi1> loc(#loc309) + %cond_33 = arith.cmpi eq, %right_valid_mask_28, %left_valid_mask_27 : tensor<8x16xi1> loc(#loc310) + %cond_34 = arith.andi %cond_33, %cond_31 : tensor<8x16xi1> loc(#loc311) + %cond_35 = arith.ori %cond_32, %cond_34 : tensor<8x16xi1> loc(#loc312) + %cond_36 = arith.extui %cond_35 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc313) + %cond_37 = arith.xori %cond_36, %flip : tensor<8x16xi32> loc(#loc313) + 
%cond_38 = arith.constant 0 : i32 loc(#loc314) + %cond_39 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc314) + %cond_40 = arith.cmpi ne, %cond_37, %cond_39 : tensor<8x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_41 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_42 = arith.select %cond_40, %ret, %ret_41 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_43 = arith.xori %x, %ret_42 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_25, %right_idx_26 : tensor<8x16xi32> loc(#loc319) + %new_idxs_44 = tt.call @triton.language.standard.zeros_like__i16S8_16S__(%idxs) : (tensor<8x16xi16>) -> tensor<8x16xi16> loc(#loc320) + %new_idxs_45 = arith.extsi %new_idxs_44 : tensor<8x16xi16> to tensor<8x16xi32> loc(#loc321) + %new_idxs_46 = arith.select %cond_40, %new_idxs, %new_idxs_45 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_47 = arith.extsi %idxs : tensor<8x16xi16> to tensor<8x16xi32> loc(#loc322) + %new_idxs_48 = arith.xori %new_idxs_47, %new_idxs_46 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_43, %new_idxs_48 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x2x1xi32> loc("input"(#loc166))) -> tensor<64x1xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc167) + 
tt.return %0 : tensor<64x1xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64x1xi32> loc(#loc169) + tt.return %1 : tensor<64x1xi32> loc(#loc169) + } loc(#loc166) + tt.func private @triton.language.standard._sum_combine__i32_i32__(%a: i32 loc("a"(#loc170)), %b: i32 loc("b"(#loc170))) -> i32 attributes {noinline = false} { + %0 = arith.addi %a, %b : i32 loc(#loc171) + tt.return %0 : i32 loc(#loc172) + ^bb1: // no predecessors + %1 = ub.poison : i32 loc(#loc173) + tt.return %1 : i32 loc(#loc173) + } loc(#loc170) + tt.func private @"triton.language.standard.sum__i16S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x2x1xi16> loc("input"(#loc166))) -> tensor<64x1xi32> attributes {noinline = false} { + %input_0 = arith.extsi %input : tensor<64x2x1xi16> to tensor<64x2x1xi32> loc(#loc326) + %0 = "tt.reduce"(%input_0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc167) + tt.return %0 : tensor<64x1xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64x1xi32> loc(#loc169) + tt.return %1 : tensor<64x1xi32> loc(#loc169) + } loc(#loc166) + tt.func private @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%x: tensor<8x16xi32> loc("x"(#loc175))) -> i1 attributes {noinline = false} { + %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc176) + %false = arith.constant false loc(#loc177) + tt.return %false : i1 loc(#loc177) + ^bb1: // no predecessors + %1 = ub.poison : i1 loc(#loc178) + tt.return %1 : i1 loc(#loc178) + } loc(#loc175) + tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S8_16S__(%x: tensor<8x16xi32> loc("x"(#loc179))) -> tensor<8x16xi32> 
attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1> loc(#loc180) + %1 = arith.extui %0 : tensor<1xi1> to tensor<1xi32> loc(#loc181) + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc181) + %3 = tt.broadcast %2 : tensor<1x1xi32> -> tensor<8x16xi32> loc(#loc181) + %4 = arith.addi %x, %3 : tensor<8x16xi32> loc(#loc181) + tt.return %4 : tensor<8x16xi32> loc(#loc182) + ^bb1: // no predecessors + %5 = ub.poison : tensor<8x16xi32> loc(#loc183) + tt.return %5 : tensor<8x16xi32> loc(#loc183) + } loc(#loc179) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} { + %false = arith.constant false loc(#loc185) + %cst = arith.constant dense : tensor<1xi1> loc(#loc185) + tt.return %cst : tensor<1xi1> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1xi1> loc(#loc187) + tt.return %0 : tensor<1xi1> loc(#loc187) + } loc(#loc184) + tt.func private @triton.language.standard.zeros_like__i32S8_16S__(%input: tensor<8x16xi32> loc("input"(#loc188))) -> tensor<8x16xi32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_8__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() : () -> tensor<8x16xi32> loc(#loc189) + tt.return %0 : tensor<8x16xi32> loc(#loc190) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x16xi32> loc(#loc191) + tt.return %1 : tensor<8x16xi32> loc(#loc191) + } loc(#loc188) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_8__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() -> tensor<8x16xi32> attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 loc(#loc185) + %cst = arith.constant dense<0> : tensor<8x16xi32> loc(#loc185) + tt.return %cst : tensor<8x16xi32> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<8x16xi32> loc(#loc187) + tt.return %0 : 
tensor<8x16xi32> loc(#loc187) + } loc(#loc184) + tt.func private @triton.language.standard.zeros_like__i16S8_16S__(%input: tensor<8x16xi16> loc("input"(#loc188))) -> tensor<8x16xi16> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_8__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() : () -> tensor<8x16xi16> loc(#loc189) + tt.return %0 : tensor<8x16xi16> loc(#loc190) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x16xi16> loc(#loc191) + tt.return %1 : tensor<8x16xi16> loc(#loc191) + } loc(#loc188) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_8__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() -> tensor<8x16xi16> attributes {noinline = false} { + %c0_i16 = arith.constant 0 : i16 loc(#loc185) + %cst = arith.constant dense<0> : tensor<8x16xi16> loc(#loc185) + tt.return %cst : tensor<8x16xi16> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<8x16xi16> loc(#loc187) + tt.return %0 : tensor<8x16xi16> loc(#loc187) + } loc(#loc184) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc95)), %idxs: tensor<8x16xi32> loc("idxs"(#loc95))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc264) + %0:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + tt.return %1#0, %1#1 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc101) + ^bb1: // no predecessors + %2 = ub.poison : tensor<8x16xi32> loc(#loc102) + %3 = ub.poison : tensor<8x16xi32> loc(#loc102) + tt.return %2, %3 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: tensor<8x16xi32> loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi 
%left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<32x2x2xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<32x2x2xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<32x2x2xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> 
loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<32x2x2xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<8x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<8x16xi1> loc(#loc301) + %cond_49 = arith.ori %cond, %cond_48 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : 
tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<8x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<8x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_43 
= arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<32x2x2xi32> loc("input"(#loc166))) -> tensor<32x2xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc167) + tt.return %0 : tensor<32x2xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32x2xi32> loc(#loc169) + tt.return %1 : tensor<32x2xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: tensor<8x16xi32> loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> 
tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<64x2x1xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<64x2x1xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<64x2x1xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<64x2x1xi32>) -> 
tensor<64x1xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<64x2x1xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<8x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<8x16xi1> loc(#loc301) + %cond_49 = 
arith.ori %cond, %cond_48 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<8x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<8x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<8x16xi32> loc(#loc318) + 
%new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc95)), %idxs: tensor<8x16xi32> loc("idxs"(#loc95))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc264) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %1:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip_3) : (tensor<8x16xi32>, tensor<8x16xi32>, tensor<8x16xi32>) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + tt.return %2#0, %2#1 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc101) + ^bb1: // no predecessors + %3 = ub.poison : tensor<8x16xi32> loc(#loc102) + %4 = ub.poison : tensor<8x16xi32> loc(#loc102) + tt.return %3, %4 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: tensor<8x16xi32> loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi 
%left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<16x2x4xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<16x2x4xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<16x2x4xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> 
loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<16x2x4xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<8x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<8x16xi1> loc(#loc301) + %cond_49 = arith.ori %cond, %cond_48 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : 
tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<8x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<8x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_43 
= arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<16x2x4xi32> loc("input"(#loc166))) -> tensor<16x4xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc167) + tt.return %0 : tensor<16x4xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<16x4xi32> loc(#loc169) + tt.return %1 : tensor<16x4xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S8_16S_i32S8_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc95)), %idxs: tensor<8x16xi32> loc("idxs"(#loc95))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %flip = arith.constant false loc(#loc330) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip) : (tensor<8x16xi32>, tensor<8x16xi32>, i1) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %1:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip) : (tensor<8x16xi32>, tensor<8x16xi32>, i1) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip) : (tensor<8x16xi32>, tensor<8x16xi32>, i1) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1, %flip) : (tensor<8x16xi32>, tensor<8x16xi32>, i1) -> (tensor<8x16xi32>, tensor<8x16xi32>) loc(#loc100) + tt.return %3#0, %3#1 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc101) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc102) + %5 = ub.poison : tensor<8x16xi32> loc(#loc102) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : 
tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<8x2x8xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S8_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<8x2x8xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S8_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<8x2x8xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S8_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<8x2x8xi32>) -> 
tensor<8x8xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<8x2x8xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S8_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<8x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<8x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, 
%cond_45 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<8x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<8x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + 
%new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S8_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<8x2x8xi32> loc("input"(#loc166))) -> tensor<8x8xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc167) + tt.return %0 : tensor<8x8xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x8xi32> loc(#loc169) + tt.return %1 : tensor<8x8xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + 
%left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<16x2x4xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<16x2x4xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<16x2x4xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc285) + 
%left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<16x2x4xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S16_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<8x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<8x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : 
tensor<8x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<8x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<8x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_40 = 
arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: tensor<8x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<32x2x2xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc274) + %ileft_8 = 
tt.broadcast %ileft_7 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<32x2x2xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<32x2x2xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<32x2x2xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S32_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc291) + %right_idx_22 = tt.broadcast 
%right_idx_21 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<8x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<8x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : 
tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<8x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<8x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S8_16S_i32S8_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<8x16xi32> loc("x"(#loc103)), %idxs: 
tensor<8x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<8x16xi32>, tensor<8x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<64x2x1xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<64x2x1xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + 
%iright_14 = tt.reshape %iright_12 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<64x2x1xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<64x2x1xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S64_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<8x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<8x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<8x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<8x16xi32> loc(#loc298) + %cond = arith.cmpi slt, 
%ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<8x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<8x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<8x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<8x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<8x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<8x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<8x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S8_16S__(%ileft_13) : (tensor<8x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<8x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<8x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<8x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<8x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<8x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<8x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<8x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<8x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<8x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<8x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<8x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<8x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<8x16xi32> loc(#loc315) + 
%ret_36 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%x) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<8x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<8x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S8_16S__(%idxs) : (tensor<8x16xi32>) -> tensor<8x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<8x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<8x16xi32> loc(#loc165) + %5 = ub.poison : tensor<8x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<8x16xi32>, tensor<8x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i64S8_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<8x16xi64> loc("input"(#loc166))) -> tensor<8xi64> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i64 loc(unknown), %arg2: i64 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i64_i64__(%arg1, %arg2) : (i64, i64) -> i64 loc(#loc167) + tt.reduce.return %2 : i64 loc(#loc167) + }) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc167) + tt.return %0 : tensor<8xi64> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8xi64> loc(#loc169) + tt.return %1 : tensor<8xi64> loc(#loc169) + } loc(#loc166) + tt.func private @triton.language.standard._sum_combine__i64_i64__(%a: i64 loc("a"(#loc170)), %b: i64 loc("b"(#loc170))) -> i64 attributes {noinline = false} { + %0 = arith.addi %a, %b : i64 loc(#loc171) + tt.return %0 : i64 loc(#loc172) + ^bb1: // no predecessors + %1 = ub.poison : i64 
loc(#loc173) + tt.return %1 : i64 loc(#loc173) + } loc(#loc170) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":20:15) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:28) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":28:16) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":29:48) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc15 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":35:30) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":37:34) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":45:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc29 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":62:21) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":63:21) +#loc43 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":65:32) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":67:44) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":72:31) +#loc57 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":73:21) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":74:21) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:55) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc71 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:35) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:32) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:35) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:32) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc85 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:52) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc92 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc93 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:11) +#loc94 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:4) +#loc96 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc97 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc98 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc99 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc100 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc101 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:11) +#loc102 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:4) +#loc104 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc105 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:30) +#loc106 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:33) +#loc107 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc108 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc109 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc110 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc111 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc112 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc113 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc114 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc115 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc116 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc117 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc118 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc119 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc120 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc121 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc122 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc123 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc124 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc125 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc126 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc127 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc128 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc129 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc130 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc131 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":558:49) +#loc132 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":559:50) +#loc133 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":570:25) +#loc134 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":571:27) +#loc135 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc136 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:23) +#loc137 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:11) +#loc138 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:47) +#loc139 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:46) +#loc140 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:31) +#loc142 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc143 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:23) +#loc144 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:11) +#loc145 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:36) +#loc146 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:23) +#loc147 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc148 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc149 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc150 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":596:31) +#loc151 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:29) +#loc152 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:48) +#loc153 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:8) +#loc154 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc155 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc156 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc157 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:60) +#loc158 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc159 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc160 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc161 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:73) +#loc162 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc163 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc164 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:11) +#loc165 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:4) +#loc167 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc168 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc169 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc171 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc172 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc173 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc174 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc176 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:29) 
+#loc177 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:11) +#loc178 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:4) +#loc180 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:30) +#loc181 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:15) +#loc182 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:11) +#loc183 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:4) +#loc184 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc185 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc186 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc187 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc189 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:30) +#loc190 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:11) +#loc191 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:4) +#loc192 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":630:15) +#loc202 = loc("xnumel"(#loc1)) +#loc203 = loc("r0_numel"(#loc2)) +#loc204 = loc("xoffset"(#loc3)) +#loc205 = loc("xoffset"(#loc4)) +#loc206 = loc("xindex"(#loc5)) +#loc207 = loc("xindex"(#loc6)) +#loc208 = loc("xindex"(#loc7)) +#loc209 = loc("xmask"(#loc8)) +#loc210 = loc("r0_index"(#loc9)) +#loc211 = loc("r0_index"(#loc10)) +#loc212 = loc("r0_offset"(#loc11)) +#loc213 = loc("r0_mask"(#loc12)) +#loc214 = loc("tmp0"(#loc13)) +#loc215 = loc("tmp0"(#loc14)) +#loc216 = 
loc("tmp0"(#loc15)) +#loc217 = loc("tmp0"(#loc16)) +#loc218 = loc("tmp1"(#loc17)) +#loc219 = loc("tmp2"(#loc18)) +#loc220 = loc("tmp3"(#loc19)) +#loc221 = loc("tmp4"(#loc20)) +#loc222 = loc("tmp5"(#loc21)) +#loc223 = loc("tmp6"(#loc22)) +#loc224 = loc("tmp7"(#loc23)) +#loc225 = loc("tmp9"(#loc24)) +#loc226 = loc("tmp11"(#loc25)) +#loc227 = loc("tmp14"(#loc27)) +#loc228 = loc("tmp15"(#loc28)) +#loc229 = loc("tmp16"(#loc29)) +#loc230 = loc("tmp20"(#loc31)) +#loc231 = loc("tmp23"(#loc32)) +#loc232 = loc("tmp24"(#loc33)) +#loc233 = loc("tmp24"(#loc34)) +#loc234 = loc("tmp25"(#loc35)) +#loc235 = loc("tmp28"(#loc36)) +#loc236 = loc("tmp29"(#loc37)) +#loc237 = loc("tmp29"(#loc38)) +#loc238 = loc("tmp30"(#loc39)) +#loc239 = loc("tmp31"(#loc40)) +#loc240 = loc("tmp32"(#loc41)) +#loc241 = loc("tmp33"(#loc42)) +#loc242 = loc("tmp34"(#loc43)) +#loc243 = loc("tmp35"(#loc44)) +#loc244 = loc("tmp36"(#loc45)) +#loc245 = loc("tmp37"(#loc46)) +#loc246 = loc("tmp38"(#loc47)) +#loc247 = loc("tmp39"(#loc48)) +#loc248 = loc("tmp40"(#loc49)) +#loc249 = loc("tmp42"(#loc56)) +#loc250 = loc("tmp43"(#loc57)) +#loc251 = loc("tmp44"(#loc58)) +#loc252 = loc("tmp45"(#loc59)) +#loc253 = loc("tmp46"(#loc60)) +#loc254 = loc("tmp47"(#loc61)) +#loc255 = loc("tmp48"(#loc62)) +#loc256 = loc("tmp49"(#loc63)) +#loc261 = loc("flip"(#loc96)) +#loc262 = loc("flip"(#loc97)) +#loc263 = loc("flip"(#loc98)) +#loc264 = loc("flip"(#loc99)) +#loc268 = loc("y"(#loc104)) +#loc269 = loc("right_mask"(#loc105)) +#loc270 = loc("right_mask"(#loc106)) +#loc271 = loc("left_mask"(#loc107)) +#loc272 = loc("ileft"(#loc108)) +#loc273 = loc("ileft"(#loc109)) +#loc274 = loc("ileft"(#loc110)) +#loc275 = loc("ileft"(#loc111)) +#loc276 = loc("iright"(#loc112)) +#loc277 = loc("iright"(#loc113)) +#loc278 = loc("iright"(#loc114)) +#loc279 = loc("iright"(#loc115)) +#loc280 = loc("ileft"(#loc116)) +#loc281 = loc("iright"(#loc117)) +#loc282 = loc("y_idx"(#loc118)) +#loc283 = loc("left_idx"(#loc119)) +#loc284 = loc("left_idx"(#loc120)) 
+#loc285 = loc("left_idx"(#loc121)) +#loc286 = loc("left_idx"(#loc122)) +#loc287 = loc("left_idx"(#loc123)) +#loc288 = loc("right_idx"(#loc124)) +#loc289 = loc("right_idx"(#loc125)) +#loc290 = loc("right_idx"(#loc126)) +#loc291 = loc("right_idx"(#loc127)) +#loc292 = loc("right_idx"(#loc128)) +#loc293 = loc("left_idx"(#loc129)) +#loc294 = loc("right_idx"(#loc130)) +#loc295 = loc("left_valid_mask"(#loc131)) +#loc296 = loc("right_valid_mask"(#loc132)) +#loc297 = loc("left_isnan"(#loc133)) +#loc298 = loc("right_isnan"(#loc134)) +#loc299 = loc("cond"(#loc135)) +#loc300 = loc("cond"(#loc138)) +#loc301 = loc("cond"(#loc139)) +#loc302 = loc("cond"(#loc140)) +#loc303 = loc("eq"(#loc142)) +#loc304 = loc("eq"(#loc145)) +#loc305 = loc("eq"(#loc146)) +#loc306 = loc("cond"(#loc147)) +#loc307 = loc("cond"(#loc148)) +#loc308 = loc("cond"(#loc149)) +#loc309 = loc("cond"(#loc150)) +#loc310 = loc("cond"(#loc151)) +#loc311 = loc("cond"(#loc152)) +#loc312 = loc("cond"(#loc153)) +#loc313 = loc("cond"(#loc154)) +#loc314 = loc("cond"(#loc155)) +#loc315 = loc("ret"(#loc156)) +#loc316 = loc("ret"(#loc157)) +#loc317 = loc("ret"(#loc158)) +#loc318 = loc("ret"(#loc159)) +#loc319 = loc("new_idxs"(#loc160)) +#loc320 = loc("new_idxs"(#loc161)) +#loc321 = loc("new_idxs"(#loc162)) +#loc322 = loc("new_idxs"(#loc163)) +#loc326 = loc("input"(#loc174)) +#loc330 = loc("flip"(#loc192)) +#loc331 = loc("cond"(#loc299)) +#loc332 = loc("cond"(#loc302)) +#loc333 = loc("eq"(#loc303)) +#loc334 = loc("eq"(#loc305)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir new file mode 100644 index 
0000000000000000000000000000000000000000..71376c5f269bd92bf754ce05da87da679d0b07a6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir @@ -0,0 +1,1480 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 2, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [8, 2, 2], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 2, 1], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc1 = loc(unknown) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc48 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) +#loc117 = loc("in_ptr0"(#loc)) +#loc118 = loc("out_ptr4"(#loc)) +#loc119 = loc("out_ptr5"(#loc)) +#loc120 = loc("out_ptr6"(#loc)) +#loc121 = loc("out_ptr7"(#loc)) +#loc122 = loc("out_ptr8"(#loc)) +#loc123 = loc("out_ptr9"(#loc)) +#loc124 = loc("xnumel"(#loc)) +#loc125 = loc("r0_numel"(#loc)) +#loc144 = loc(callsite(#loc20 at #loc21)) +#loc150 = loc("ileft"(#loc29)) +#loc154 = loc("iright"(#loc34)) +#loc163 = loc("left_idx"(#loc43)) +#loc168 = loc("right_idx"(#loc48)) +#loc189 = loc(callsite(#loc20 at #loc69)) +#loc192 = loc("tmp24"(#loc72)) +#loc197 = loc("tmp29"(#loc77)) +#loc214 = loc(callsite(#loc25 at #loc144)) +#loc218 = loc(callsite(#loc25 at #loc189)) +#loc221 = loc(callsite(#loc1 at #loc192)) +#loc224 = loc(callsite(#loc1 at #loc197)) +#loc228 = loc(callsite(#loc150 at #loc214)) +#loc232 = loc(callsite(#loc154 at #loc214)) +#loc240 = loc(callsite(#loc163 at #loc214)) +#loc245 = loc(callsite(#loc168 at #loc214)) +#loc265 = loc(callsite(#loc150 at #loc218)) +#loc269 = loc(callsite(#loc154 at #loc218)) +#loc287 = loc(callsite(#loc163 at #loc218)) +#loc291 = loc(callsite(#loc168 at #loc218)) +#loc301 = loc(callsite(#loc1 at #loc228)) +#loc303 = loc(callsite(#loc1 at #loc232)) +#loc306 = loc(callsite(#loc1 at #loc240)) +#loc309 = loc(callsite(#loc1 at #loc245)) +#loc311 = loc(callsite(#loc1 at #loc265)) +#loc313 = loc(callsite(#loc1 at #loc269)) +#loc315 = loc(callsite(#loc1 at #loc287)) +#loc317 = loc(callsite(#loc1 at #loc291)) +module attributes 
{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense : tensor<8x1xi1, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<17> : tensor<8x1xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked2> loc(#loc1) + %cst_3 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked3> loc(#loc1) + %cst_4 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked4> loc(#loc1) + %cst_5 = arith.constant dense<16> : tensor<8x1xi32, #blocked> loc(#loc1) + %cst_6 = arith.constant dense<32> : tensor<8x1xi32, #blocked5> loc(#loc1) + %cst_7 = arith.constant dense<32> : tensor<8x1xi32, #blocked> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %cst_8 = arith.constant dense<1> : tensor<8x16xi32, #blocked5> loc(#loc1) + %cst_9 = arith.constant dense<0> : tensor<8x16xi32, #blocked> loc(#loc1) + %cst_10 = arith.constant dense<17> : tensor<8x16xi32, #blocked> loc(#loc1) + %cst_11 = arith.constant dense<16> : tensor<8x16xi32, #blocked> loc(#loc1) + %cst_12 = arith.constant dense<16384> : tensor<8x16xi64, 
#blocked> loc(#loc1) + %cst_13 = arith.constant dense<0> : tensor<8x16xi64, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc126) + %xoffset_14 = arith.muli %xoffset, %c8_i32 : i32 loc(#loc127) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc128) + %xindex_15 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc128) + %xindex_16 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi32, #blocked> loc(#loc128) + %xindex_17 = tt.expand_dims %xindex_15 {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1xi32, #blocked5> loc(#loc128) + %xindex_18 = tt.splat %xoffset_14 : i32 -> tensor<8x1xi32, #blocked> loc(#loc129) + %xindex_19 = tt.splat %xoffset_14 : i32 -> tensor<8x1xi32, #blocked5> loc(#loc129) + %xindex_20 = arith.addi %xindex_18, %xindex_16 : tensor<8x1xi32, #blocked> loc(#loc129) + %xindex_21 = arith.addi %xindex_19, %xindex_17 : tensor<8x1xi32, #blocked5> loc(#loc129) + %xmask = arith.cmpi slt, %xindex_20, %cst_7 : tensor<8x1xi32, #blocked> loc(#loc130) + %xmask_22 = arith.cmpi slt, %xindex_21, %cst_6 : tensor<8x1xi32, #blocked5> loc(#loc130) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc131) + %r0_index_23 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc131) + %tmp0 = arith.muli %xindex_20, %cst_5 : tensor<8x1xi32, #blocked> loc(#loc132) + %tmp0_24 = tt.broadcast %r0_index_23 : tensor<1x16xi32, #blocked> -> tensor<8x16xi32, #blocked> loc(#loc133) + %tmp0_25 = tt.broadcast %tmp0 : tensor<8x1xi32, #blocked> -> tensor<8x16xi32, #blocked> loc(#loc133) + %tmp0_26 = arith.addi %tmp0_24, %tmp0_25 : tensor<8x16xi32, #blocked> loc(#loc133) + 
%tmp0_27 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc134) + %tmp0_28 = tt.addptr %tmp0_27, %tmp0_26 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi32, #blocked> loc(#loc134) + %tmp0_29 = tt.broadcast %xmask : tensor<8x1xi1, #blocked> -> tensor<8x16xi1, #blocked> loc(#loc135) + %tmp0_30 = tt.broadcast %xmask_22 : tensor<8x1xi1, #blocked5> -> tensor<8x16xi1, #blocked5> loc(#loc135) + %tmp0_31 = tt.load %tmp0_28, %tmp0_29, %cst_13 : tensor<8x16x!tt.ptr, #blocked> loc(#loc135) + %tmp2 = arith.cmpi sgt, %tmp0_31, %cst_13 : tensor<8x16xi64, #blocked> loc(#loc136) + %tmp4 = arith.cmpi slt, %tmp0_31, %cst_12 : tensor<8x16xi64, #blocked> loc(#loc137) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<8x16xi1, #blocked> loc(#loc138) + %tmp7 = arith.extui %tmp5 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc210) + %tmp9 = arith.trunci %r0_index_23 : tensor<1x16xi32, #blocked> to tensor<1x16xi16, #blocked> loc(#loc141) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16, #blocked> -> tensor<8x16xi16, #blocked> loc(#loc142) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> loc(#loc211) + %flip_32 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> loc(#loc211) + %flip_33 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> loc(#loc211) + %flip_34 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> loc(#loc211) + %flip_35 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> loc(#loc211) + %flip_36 = tt.expand_dims %flip_32 {axis = 0 : i32} : 
tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> loc(#loc211) + %flip_37 = tt.expand_dims %flip_33 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> loc(#loc211) + %flip_38 = tt.expand_dims %flip_34 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> loc(#loc211) + %flip_39 = tt.expand_dims %flip_35 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> -> tensor<1x2x1xi32, #blocked3> loc(#loc211) + %flip_40 = tt.expand_dims %flip_36 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> -> tensor<1x2x1xi32, #blocked4> loc(#loc211) + %flip_41 = tt.expand_dims %flip_37 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> loc(#loc211) + %flip_42 = tt.expand_dims %flip_38 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> -> tensor<1x2x1xi32, #blocked1> loc(#loc211) + %flip_43 = tt.broadcast %flip_39 : tensor<1x2x1xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc212) + %flip_44 = tt.reshape %flip_43 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc213) + %y = tt.reshape %tmp7 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc225) + %left_mask = arith.subi %cst_4, %flip_40 : tensor<1x2x1xi32, #blocked4> loc(#loc226) + %left_mask_45 = arith.subi %cst_3, %flip_39 : tensor<1x2x1xi32, #blocked3> loc(#loc226) + %left_mask_46 = arith.subi %cst_2, %flip_41 : tensor<1x2x1xi32, #blocked2> loc(#loc226) + %left_mask_47 = arith.subi %cst_1, %flip_42 : tensor<1x2x1xi32, #blocked1> loc(#loc226) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32, #blocked4> -> 
tensor<64x2x1xi32, #blocked4> loc(#loc227) + %ileft_48 = arith.muli %y, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc227) + %ileft_49 = "tt.reduce"(%ileft_48) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_50 = tt.expand_dims %ileft_49 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc229) + %ileft_51 = tt.broadcast %ileft_50 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc230) + %iright = tt.broadcast %flip_40 : tensor<1x2x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc231) + %iright_52 = arith.muli %y, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc231) + %iright_53 = "tt.reduce"(%iright_52) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_54 = tt.expand_dims %iright_53 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc233) + %iright_55 = tt.broadcast %iright_54 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc234) + %ileft_56 = tt.reshape %ileft_51 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_57 = tt.reshape %iright_55 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx = tt.reshape %tmp11 : tensor<8x16xi16, #blocked> -> tensor<64x2x1xi16, #blocked4> 
loc(#loc237) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc238) + %left_idx_58 = tt.broadcast %left_idx : tensor<1x2x1xi16, #blocked4> -> tensor<64x2x1xi16, #blocked4> loc(#loc239) + %left_idx_59 = arith.muli %y_idx, %left_idx_58 : tensor<64x2x1xi16, #blocked4> loc(#loc239) + %input = arith.extsi %left_idx_59 : tensor<64x2x1xi16, #blocked4> to tensor<64x2x1xi32, #blocked4> loc(#loc304) + %left_idx_60 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_61 = tt.expand_dims %left_idx_60 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc241) + %left_idx_62 = tt.broadcast %left_idx_61 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc242) + %right_idx = arith.trunci %flip_40 : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc243) + %right_idx_63 = tt.broadcast %right_idx : tensor<1x2x1xi16, #blocked4> -> tensor<64x2x1xi16, #blocked4> loc(#loc244) + %right_idx_64 = arith.muli %y_idx, %right_idx_63 : tensor<64x2x1xi16, #blocked4> loc(#loc244) + %input_65 = arith.extsi %right_idx_64 : tensor<64x2x1xi16, #blocked4> to tensor<64x2x1xi32, #blocked4> loc(#loc307) + %right_idx_66 = "tt.reduce"(%input_65) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, 
parent = #blocked4}>> loc(#loc308) + %right_idx_67 = tt.expand_dims %right_idx_66 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc246) + %right_idx_68 = tt.broadcast %right_idx_67 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc247) + %left_idx_69 = tt.reshape %left_idx_62 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_70 = tt.reshape %right_idx_68 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond = arith.cmpi slt, %ileft_56, %iright_57 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq = arith.cmpi eq, %ileft_56, %iright_57 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_71 = arith.cmpi sgt, %left_idx_69, %right_idx_70 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_72 = arith.andi %eq, %cond_71 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_73 = arith.ori %cond, %cond_72 : tensor<8x16xi1, #blocked> loc(#loc254) + %cond_74 = arith.extui %cond_73 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_75 = arith.xori %cond_74, %flip_44 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_76 = arith.cmpi ne, %cond_75, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret = arith.xori %ileft_56, %iright_57 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_77 = arith.select %cond_76, %ret, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_78 = arith.xori %tmp7, %ret_77 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs = arith.xori %left_idx_69, %right_idx_70 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_79 = arith.select %cond_76, %new_idxs, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_80 = arith.extsi %tmp9 : tensor<1x16xi16, #blocked> to tensor<1x16xi32, #blocked> loc(#loc262) + %new_idxs_81 = tt.broadcast %new_idxs_80 : tensor<1x16xi32, #blocked> -> tensor<8x16xi32, #blocked> 
loc(#loc262) + %new_idxs_82 = arith.xori %new_idxs_81, %new_idxs_79 : tensor<8x16xi32, #blocked> loc(#loc262) + %flip_83 = tt.broadcast %flip_41 : tensor<1x2x1xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc212) + %flip_84 = tt.reshape %flip_83 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc213) + %y_85 = tt.reshape %ret_78 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc225) + %ileft_86 = tt.broadcast %left_mask_45 : tensor<1x2x1xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc227) + %ileft_87 = arith.muli %y_85, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc227) + %ileft_88 = "tt.reduce"(%ileft_87) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_89 = tt.expand_dims %ileft_88 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc229) + %ileft_90 = tt.broadcast %ileft_89 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc230) + %iright_91 = arith.muli %y_85, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc231) + %iright_92 = "tt.reduce"(%iright_91) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_93 = tt.expand_dims %iright_92 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc233) + %iright_94 = 
tt.broadcast %iright_93 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc234) + %ileft_95 = tt.reshape %ileft_90 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_96 = tt.reshape %iright_94 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_97 = tt.reshape %new_idxs_82 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc237) + %left_idx_98 = arith.muli %y_idx_97, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc239) + %left_idx_99 = "tt.reduce"(%left_idx_98) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_100 = tt.expand_dims %left_idx_99 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc241) + %left_idx_101 = tt.broadcast %left_idx_100 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc242) + %right_idx_102 = arith.muli %y_idx_97, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc244) + %right_idx_103 = "tt.reduce"(%right_idx_102) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_104 = tt.expand_dims %right_idx_103 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc246) + %right_idx_105 = tt.broadcast %right_idx_104 : 
tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc247) + %left_idx_106 = tt.reshape %left_idx_101 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_107 = tt.reshape %right_idx_105 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_108 = arith.cmpi slt, %ileft_95, %iright_96 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_109 = arith.cmpi eq, %ileft_95, %iright_96 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_110 = arith.cmpi sgt, %left_idx_106, %right_idx_107 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_111 = arith.andi %eq_109, %cond_110 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_112 = arith.ori %cond_108, %cond_111 : tensor<8x16xi1, #blocked> loc(#loc254) + %cond_113 = arith.extui %cond_112 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_114 = arith.xori %cond_113, %flip_84 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_115 = arith.cmpi ne, %cond_114, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret_116 = arith.xori %ileft_95, %iright_96 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_117 = arith.select %cond_115, %ret_116, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_118 = arith.xori %ret_78, %ret_117 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_119 = arith.xori %left_idx_106, %right_idx_107 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_120 = arith.select %cond_115, %new_idxs_119, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_121 = arith.xori %new_idxs_82, %new_idxs_120 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_122 = tt.reshape %ret_118 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc225) + %ileft_123 = arith.muli %y_122, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc227) + %ileft_124 = "tt.reduce"(%ileft_123) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), 
%ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_125 = tt.expand_dims %ileft_124 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc229) + %ileft_126 = tt.broadcast %ileft_125 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc230) + %iright_127 = arith.muli %y_122, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc231) + %iright_128 = "tt.reduce"(%iright_127) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_129 = tt.expand_dims %iright_128 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc233) + %iright_130 = tt.broadcast %iright_129 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc234) + %ileft_131 = tt.reshape %ileft_126 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_132 = tt.reshape %iright_130 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_133 = tt.reshape %new_idxs_121 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc237) + %left_idx_134 = arith.muli %y_idx_133, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc239) + %left_idx_135 = "tt.reduce"(%left_idx_134) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi 
%left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_136 = tt.expand_dims %left_idx_135 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc241) + %left_idx_137 = tt.broadcast %left_idx_136 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc242) + %right_idx_138 = arith.muli %y_idx_133, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc244) + %right_idx_139 = "tt.reduce"(%right_idx_138) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_140 = tt.expand_dims %right_idx_139 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc246) + %right_idx_141 = tt.broadcast %right_idx_140 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc247) + %left_idx_142 = tt.reshape %left_idx_137 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_143 = tt.reshape %right_idx_141 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_144 = arith.cmpi slt, %ileft_131, %iright_132 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_145 = arith.cmpi eq, %ileft_131, %iright_132 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_146 = arith.cmpi sgt, %left_idx_142, %right_idx_143 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_147 = arith.andi %eq_145, %cond_146 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_148 = arith.ori %cond_144, %cond_147 : 
tensor<8x16xi1, #blocked> loc(#loc254) + %cond_149 = arith.extui %cond_148 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_150 = arith.xori %cond_149, %flip_84 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_151 = arith.cmpi ne, %cond_150, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret_152 = arith.xori %ileft_131, %iright_132 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_153 = arith.select %cond_151, %ret_152, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_154 = arith.xori %ret_118, %ret_153 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_155 = arith.xori %left_idx_142, %right_idx_143 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_156 = arith.select %cond_151, %new_idxs_155, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_157 = arith.xori %new_idxs_121, %new_idxs_156 : tensor<8x16xi32, #blocked> loc(#loc262) + %flip_158 = tt.broadcast %flip_42 : tensor<1x2x1xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc212) + %flip_159 = tt.reshape %flip_158 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc213) + %y_160 = tt.reshape %ret_154 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc225) + %ileft_161 = tt.broadcast %left_mask_46 : tensor<1x2x1xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc227) + %ileft_162 = arith.muli %y_160, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc227) + %ileft_163 = "tt.reduce"(%ileft_162) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc300) + %ileft_164 = tt.expand_dims %ileft_163 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = 
#blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc229) + %ileft_165 = tt.broadcast %ileft_164 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc230) + %iright_166 = arith.muli %y_160, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc231) + %iright_167 = "tt.reduce"(%iright_166) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %iright_168 = tt.expand_dims %iright_167 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc233) + %iright_169 = tt.broadcast %iright_168 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc234) + %ileft_170 = tt.reshape %ileft_165 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_171 = tt.reshape %iright_169 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_172 = tt.reshape %new_idxs_157 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc237) + %left_idx_173 = arith.muli %y_idx_172, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc239) + %left_idx_174 = "tt.reduce"(%left_idx_173) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %left_idx_175 = tt.expand_dims %left_idx_174 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc241) + 
%left_idx_176 = tt.broadcast %left_idx_175 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc242) + %right_idx_177 = arith.muli %y_idx_172, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc244) + %right_idx_178 = "tt.reduce"(%right_idx_177) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc308) + %right_idx_179 = tt.expand_dims %right_idx_178 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc246) + %right_idx_180 = tt.broadcast %right_idx_179 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc247) + %left_idx_181 = tt.reshape %left_idx_176 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_182 = tt.reshape %right_idx_180 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_183 = arith.cmpi slt, %ileft_170, %iright_171 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_184 = arith.cmpi eq, %ileft_170, %iright_171 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_185 = arith.cmpi sgt, %left_idx_181, %right_idx_182 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_186 = arith.andi %eq_184, %cond_185 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_187 = arith.ori %cond_183, %cond_186 : tensor<8x16xi1, #blocked> loc(#loc254) + %cond_188 = arith.extui %cond_187 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_189 = arith.xori %cond_188, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_190 = arith.cmpi ne, %cond_189, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret_191 = arith.xori %ileft_170, %iright_171 : tensor<8x16xi32, 
#blocked> loc(#loc257) + %ret_192 = arith.select %cond_190, %ret_191, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_193 = arith.xori %ret_154, %ret_192 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_194 = arith.xori %left_idx_181, %right_idx_182 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_195 = arith.select %cond_190, %new_idxs_194, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_196 = arith.xori %new_idxs_157, %new_idxs_195 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_197 = tt.reshape %ret_193 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc225) + %ileft_198 = arith.muli %y_197, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc227) + %ileft_199 = "tt.reduce"(%ileft_198) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_200 = tt.expand_dims %ileft_199 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc229) + %ileft_201 = tt.broadcast %ileft_200 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc230) + %iright_202 = arith.muli %y_197, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc231) + %iright_203 = "tt.reduce"(%iright_202) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_204 = tt.expand_dims %iright_203 {axis = 1 : i32} : 
tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc233) + %iright_205 = tt.broadcast %iright_204 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc234) + %ileft_206 = tt.reshape %ileft_201 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_207 = tt.reshape %iright_205 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_208 = tt.reshape %new_idxs_196 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc237) + %left_idx_209 = arith.muli %y_idx_208, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc239) + %left_idx_210 = "tt.reduce"(%left_idx_209) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_211 = tt.expand_dims %left_idx_210 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc241) + %left_idx_212 = tt.broadcast %left_idx_211 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc242) + %right_idx_213 = arith.muli %y_idx_208, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc244) + %right_idx_214 = "tt.reduce"(%right_idx_213) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_215 = tt.expand_dims %right_idx_214 {axis = 1 : i32} : tensor<32x2xi32, 
#ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc246) + %right_idx_216 = tt.broadcast %right_idx_215 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc247) + %left_idx_217 = tt.reshape %left_idx_212 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_218 = tt.reshape %right_idx_216 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_219 = arith.cmpi slt, %ileft_206, %iright_207 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_220 = arith.cmpi eq, %ileft_206, %iright_207 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_221 = arith.cmpi sgt, %left_idx_217, %right_idx_218 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_222 = arith.andi %eq_220, %cond_221 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_223 = arith.ori %cond_219, %cond_222 : tensor<8x16xi1, #blocked> loc(#loc254) + %cond_224 = arith.extui %cond_223 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_225 = arith.xori %cond_224, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_226 = arith.cmpi ne, %cond_225, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret_227 = arith.xori %ileft_206, %iright_207 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_228 = arith.select %cond_226, %ret_227, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_229 = arith.xori %ret_193, %ret_228 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_230 = arith.xori %left_idx_217, %right_idx_218 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_231 = arith.select %cond_226, %new_idxs_230, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_232 = arith.xori %new_idxs_196, %new_idxs_231 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_233 = tt.reshape %ret_229 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc225) + %ileft_234 = arith.muli %y_233, %ileft : tensor<64x2x1xi32, 
#blocked4> loc(#loc227) + %ileft_235 = "tt.reduce"(%ileft_234) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_236 = tt.expand_dims %ileft_235 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc229) + %ileft_237 = tt.broadcast %ileft_236 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc230) + %iright_238 = arith.muli %y_233, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc231) + %iright_239 = "tt.reduce"(%iright_238) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_240 = tt.expand_dims %iright_239 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc233) + %iright_241 = tt.broadcast %iright_240 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc234) + %ileft_242 = tt.reshape %ileft_237 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_243 = tt.reshape %iright_241 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_244 = tt.reshape %new_idxs_232 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc237) + %left_idx_245 = arith.muli %y_idx_244, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc239) + %left_idx_246 = "tt.reduce"(%left_idx_245) <{axis = 1 : i32}> ({ + 
^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_247 = tt.expand_dims %left_idx_246 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc241) + %left_idx_248 = tt.broadcast %left_idx_247 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc242) + %right_idx_249 = arith.muli %y_idx_244, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc244) + %right_idx_250 = "tt.reduce"(%right_idx_249) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_251 = tt.expand_dims %right_idx_250 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc246) + %right_idx_252 = tt.broadcast %right_idx_251 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc247) + %left_idx_253 = tt.reshape %left_idx_248 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_254 = tt.reshape %right_idx_252 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_255 = arith.cmpi slt, %ileft_242, %iright_243 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_256 = arith.cmpi eq, %ileft_242, %iright_243 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_257 = arith.cmpi sgt, %left_idx_253, %right_idx_254 : tensor<8x16xi32, #blocked> 
loc(#loc252) + %cond_258 = arith.andi %eq_256, %cond_257 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_259 = arith.ori %cond_255, %cond_258 : tensor<8x16xi1, #blocked> loc(#loc254) + %cond_260 = arith.extui %cond_259 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc255) + %cond_261 = arith.xori %cond_260, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc255) + %cond_262 = arith.cmpi ne, %cond_261, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc256) + %ret_263 = arith.xori %ileft_242, %iright_243 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_264 = arith.select %cond_262, %ret_263, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_265 = arith.xori %ret_229, %ret_264 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_266 = arith.xori %left_idx_253, %right_idx_254 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_267 = arith.select %cond_262, %new_idxs_266, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_268 = arith.xori %new_idxs_232, %new_idxs_267 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_269 = tt.reshape %ret_265 : tensor<8x16xi32, #blocked> -> tensor<8x2x8xi32, #blocked1> loc(#loc225) + %ileft_270 = tt.broadcast %left_mask_47 : tensor<1x2x1xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc227) + %ileft_271 = arith.muli %y_269, %ileft_270 : tensor<8x2x8xi32, #blocked1> loc(#loc227) + %ileft_272 = "tt.reduce"(%ileft_271) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc300) + %ileft_273 = tt.expand_dims %ileft_272 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc229) + %ileft_274 = 
tt.broadcast %ileft_273 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc230) + %iright_275 = arith.muli %y_269, %flip_158 : tensor<8x2x8xi32, #blocked1> loc(#loc231) + %iright_276 = "tt.reduce"(%iright_275) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc302) + %iright_277 = tt.expand_dims %iright_276 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc233) + %iright_278 = tt.broadcast %iright_277 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc234) + %ileft_279 = tt.reshape %ileft_274 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_280 = tt.reshape %iright_278 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_281 = tt.reshape %new_idxs_268 : tensor<8x16xi32, #blocked> -> tensor<8x2x8xi32, #blocked1> loc(#loc237) + %left_idx_282 = arith.muli %y_idx_281, %ileft_270 : tensor<8x2x8xi32, #blocked1> loc(#loc239) + %left_idx_283 = "tt.reduce"(%left_idx_282) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc305) + %left_idx_284 = tt.expand_dims %left_idx_283 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc241) + %left_idx_285 = tt.broadcast %left_idx_284 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, 
#blocked1> loc(#loc242) + %right_idx_286 = arith.muli %y_idx_281, %flip_158 : tensor<8x2x8xi32, #blocked1> loc(#loc244) + %right_idx_287 = "tt.reduce"(%right_idx_286) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc308) + %right_idx_288 = tt.expand_dims %right_idx_287 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc246) + %right_idx_289 = tt.broadcast %right_idx_288 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc247) + %left_idx_290 = tt.reshape %left_idx_285 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_291 = tt.reshape %right_idx_289 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_292 = arith.cmpi slt, %ileft_279, %iright_280 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_293 = arith.cmpi eq, %ileft_279, %iright_280 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_294 = arith.cmpi sgt, %left_idx_290, %right_idx_291 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_295 = arith.andi %eq_293, %cond_294 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_296 = arith.ori %cond_292, %cond_295 : tensor<8x16xi1, #blocked> loc(#loc254) + %ret_297 = arith.xori %ileft_279, %iright_280 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_298 = arith.select %cond_296, %ret_297, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_299 = arith.xori %ret_265, %ret_298 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_300 = arith.xori %left_idx_290, %right_idx_291 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_301 = arith.select %cond_296, %new_idxs_300, 
%cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_302 = arith.xori %new_idxs_268, %new_idxs_301 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_303 = tt.reshape %ret_299 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc225) + %ileft_304 = arith.muli %y_303, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc227) + %ileft_305 = "tt.reduce"(%ileft_304) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc300) + %ileft_306 = tt.expand_dims %ileft_305 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc229) + %ileft_307 = tt.broadcast %ileft_306 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc230) + %iright_308 = arith.muli %y_303, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc231) + %iright_309 = "tt.reduce"(%iright_308) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %iright_310 = tt.expand_dims %iright_309 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc233) + %iright_311 = tt.broadcast %iright_310 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc234) + %ileft_312 = tt.reshape %ileft_307 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_313 = tt.reshape %iright_311 : 
tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_314 = tt.reshape %new_idxs_302 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc237) + %left_idx_315 = arith.muli %y_idx_314, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc239) + %left_idx_316 = "tt.reduce"(%left_idx_315) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %left_idx_317 = tt.expand_dims %left_idx_316 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc241) + %left_idx_318 = tt.broadcast %left_idx_317 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc242) + %right_idx_319 = arith.muli %y_idx_314, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc244) + %right_idx_320 = "tt.reduce"(%right_idx_319) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc308) + %right_idx_321 = tt.expand_dims %right_idx_320 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc246) + %right_idx_322 = tt.broadcast %right_idx_321 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc247) + %left_idx_323 = tt.reshape %left_idx_318 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_324 = tt.reshape %right_idx_322 : 
tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_325 = arith.cmpi slt, %ileft_312, %iright_313 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_326 = arith.cmpi eq, %ileft_312, %iright_313 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_327 = arith.cmpi sgt, %left_idx_323, %right_idx_324 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_328 = arith.andi %eq_326, %cond_327 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_329 = arith.ori %cond_325, %cond_328 : tensor<8x16xi1, #blocked> loc(#loc254) + %ret_330 = arith.xori %ileft_312, %iright_313 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_331 = arith.select %cond_329, %ret_330, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_332 = arith.xori %ret_299, %ret_331 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_333 = arith.xori %left_idx_323, %right_idx_324 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_334 = arith.select %cond_329, %new_idxs_333, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_335 = arith.xori %new_idxs_302, %new_idxs_334 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_336 = tt.reshape %ret_332 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc225) + %ileft_337 = arith.muli %y_336, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc227) + %ileft_338 = "tt.reduce"(%ileft_337) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_339 = tt.expand_dims %ileft_338 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc229) + %ileft_340 = tt.broadcast %ileft_339 : tensor<32x1x2xi32, #blocked3> 
-> tensor<32x2x2xi32, #blocked3> loc(#loc230) + %iright_341 = arith.muli %y_336, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc231) + %iright_342 = "tt.reduce"(%iright_341) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_343 = tt.expand_dims %iright_342 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc233) + %iright_344 = tt.broadcast %iright_343 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc234) + %ileft_345 = tt.reshape %ileft_340 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_346 = tt.reshape %iright_344 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_347 = tt.reshape %new_idxs_335 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc237) + %left_idx_348 = arith.muli %y_idx_347, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc239) + %left_idx_349 = "tt.reduce"(%left_idx_348) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_350 = tt.expand_dims %left_idx_349 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc241) + %left_idx_351 = tt.broadcast %left_idx_350 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc242) + 
%right_idx_352 = arith.muli %y_idx_347, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc244) + %right_idx_353 = "tt.reduce"(%right_idx_352) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_354 = tt.expand_dims %right_idx_353 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc246) + %right_idx_355 = tt.broadcast %right_idx_354 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc247) + %left_idx_356 = tt.reshape %left_idx_351 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_357 = tt.reshape %right_idx_355 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_358 = arith.cmpi slt, %ileft_345, %iright_346 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_359 = arith.cmpi eq, %ileft_345, %iright_346 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_360 = arith.cmpi sgt, %left_idx_356, %right_idx_357 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_361 = arith.andi %eq_359, %cond_360 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_362 = arith.ori %cond_358, %cond_361 : tensor<8x16xi1, #blocked> loc(#loc254) + %ret_363 = arith.xori %ileft_345, %iright_346 : tensor<8x16xi32, #blocked> loc(#loc257) + %ret_364 = arith.select %cond_362, %ret_363, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc258) + %ret_365 = arith.xori %ret_332, %ret_364 : tensor<8x16xi32, #blocked> loc(#loc259) + %new_idxs_366 = arith.xori %left_idx_356, %right_idx_357 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_367 = arith.select %cond_362, %new_idxs_366, %cst_9 : 
tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_368 = arith.xori %new_idxs_335, %new_idxs_367 : tensor<8x16xi32, #blocked> loc(#loc262) + %y_369 = tt.reshape %ret_365 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc225) + %ileft_370 = arith.muli %y_369, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc227) + %ileft_371 = "tt.reduce"(%ileft_370) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_372 = tt.expand_dims %ileft_371 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc229) + %ileft_373 = tt.broadcast %ileft_372 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc230) + %iright_374 = arith.muli %y_369, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc231) + %iright_375 = "tt.reduce"(%iright_374) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_376 = tt.expand_dims %iright_375 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc233) + %iright_377 = tt.broadcast %iright_376 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc234) + %ileft_378 = tt.reshape %ileft_373 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc235) + %iright_379 = tt.reshape %iright_377 : tensor<64x2x1xi32, 
#blocked4> -> tensor<8x16xi32, #blocked> loc(#loc236) + %y_idx_380 = tt.reshape %new_idxs_368 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc237) + %left_idx_381 = arith.muli %y_idx_380, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc239) + %left_idx_382 = "tt.reduce"(%left_idx_381) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_383 = tt.expand_dims %left_idx_382 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc241) + %left_idx_384 = tt.broadcast %left_idx_383 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc242) + %right_idx_385 = arith.muli %y_idx_380, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc244) + %right_idx_386 = "tt.reduce"(%right_idx_385) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_387 = tt.expand_dims %right_idx_386 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc246) + %right_idx_388 = tt.broadcast %right_idx_387 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc247) + %left_idx_389 = tt.reshape %left_idx_384 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc248) + %right_idx_390 = tt.reshape %right_idx_388 : tensor<64x2x1xi32, #blocked4> 
-> tensor<8x16xi32, #blocked> loc(#loc249) + %cond_391 = arith.cmpi slt, %ileft_378, %iright_379 : tensor<8x16xi32, #blocked> loc(#loc250) + %eq_392 = arith.cmpi eq, %ileft_378, %iright_379 : tensor<8x16xi32, #blocked> loc(#loc251) + %cond_393 = arith.cmpi sgt, %left_idx_389, %right_idx_390 : tensor<8x16xi32, #blocked> loc(#loc252) + %cond_394 = arith.andi %eq_392, %cond_393 : tensor<8x16xi1, #blocked> loc(#loc253) + %cond_395 = arith.ori %cond_391, %cond_394 : tensor<8x16xi1, #blocked> loc(#loc254) + %new_idxs_396 = arith.xori %left_idx_389, %right_idx_390 : tensor<8x16xi32, #blocked> loc(#loc260) + %new_idxs_397 = arith.select %cond_395, %new_idxs_396, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc261) + %new_idxs_398 = arith.xori %new_idxs_368, %new_idxs_397 : tensor<8x16xi32, #blocked> loc(#loc262) + %tmp14 = arith.cmpi eq, %tmp0_31, %cst_12 : tensor<8x16xi64, #blocked> loc(#loc186) + %tmp16 = arith.extui %tmp14 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc217) + %y_399 = tt.reshape %tmp16 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc263) + %ileft_400 = arith.muli %y_399, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc264) + %ileft_401 = "tt.reduce"(%ileft_400) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_402 = tt.expand_dims %ileft_401 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc266) + %ileft_403 = tt.broadcast %ileft_402 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc267) + %iright_404 = arith.muli %y_399, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc268) + %iright_405 
= "tt.reduce"(%iright_404) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_406 = tt.expand_dims %iright_405 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc270) + %iright_407 = tt.broadcast %iright_406 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc271) + %ileft_408 = tt.reshape %ileft_403 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_409 = tt.reshape %iright_407 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc273) + %cond_410 = arith.cmpi slt, %ileft_408, %iright_409 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_411 = arith.cmpi eq, %ileft_408, %iright_409 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_412 = arith.andi %eq_411, %cond_71 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_413 = arith.ori %cond_410, %cond_412 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_414 = arith.extui %cond_413 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_415 = arith.xori %cond_414, %flip_44 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_416 = arith.cmpi ne, %cond_415, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_417 = arith.xori %ileft_408, %iright_409 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_418 = arith.select %cond_416, %ret_417, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_419 = arith.xori %tmp16, %ret_418 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_420 = arith.select %cond_416, %new_idxs, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_421 = 
arith.xori %new_idxs_81, %new_idxs_420 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_422 = tt.reshape %ret_419 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc263) + %ileft_423 = arith.muli %y_422, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc264) + %ileft_424 = "tt.reduce"(%ileft_423) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + %ileft_425 = tt.expand_dims %ileft_424 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc266) + %ileft_426 = tt.broadcast %ileft_425 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc267) + %iright_427 = arith.muli %y_422, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc268) + %iright_428 = "tt.reduce"(%iright_427) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_429 = tt.expand_dims %iright_428 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc270) + %iright_430 = tt.broadcast %iright_429 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc271) + %ileft_431 = tt.reshape %ileft_426 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_432 = tt.reshape %iright_430 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_433 = tt.reshape 
%new_idxs_421 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc285) + %left_idx_434 = arith.muli %y_idx_433, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc286) + %left_idx_435 = "tt.reduce"(%left_idx_434) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_436 = tt.expand_dims %left_idx_435 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc288) + %left_idx_437 = tt.broadcast %left_idx_436 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc289) + %right_idx_438 = arith.muli %y_idx_433, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc290) + %right_idx_439 = "tt.reduce"(%right_idx_438) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_440 = tt.expand_dims %right_idx_439 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc292) + %right_idx_441 = tt.broadcast %right_idx_440 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc293) + %left_idx_442 = tt.reshape %left_idx_437 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_443 = tt.reshape %right_idx_441 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_444 = arith.cmpi slt, 
%ileft_431, %iright_432 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_445 = arith.cmpi eq, %ileft_431, %iright_432 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_446 = arith.cmpi sgt, %left_idx_442, %right_idx_443 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_447 = arith.andi %eq_445, %cond_446 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_448 = arith.ori %cond_444, %cond_447 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_449 = arith.extui %cond_448 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_450 = arith.xori %cond_449, %flip_84 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_451 = arith.cmpi ne, %cond_450, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_452 = arith.xori %ileft_431, %iright_432 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_453 = arith.select %cond_451, %ret_452, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_454 = arith.xori %ret_419, %ret_453 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_455 = arith.xori %left_idx_442, %right_idx_443 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_456 = arith.select %cond_451, %new_idxs_455, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_457 = arith.xori %new_idxs_421, %new_idxs_456 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_458 = tt.reshape %ret_454 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc263) + %ileft_459 = arith.muli %y_458, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc264) + %ileft_460 = "tt.reduce"(%ileft_459) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_461 = tt.expand_dims %ileft_460 {axis = 1 : i32} : 
tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc266) + %ileft_462 = tt.broadcast %ileft_461 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc267) + %iright_463 = arith.muli %y_458, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc268) + %iright_464 = "tt.reduce"(%iright_463) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_465 = tt.expand_dims %iright_464 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc270) + %iright_466 = tt.broadcast %iright_465 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc271) + %ileft_467 = tt.reshape %ileft_462 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_468 = tt.reshape %iright_466 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_469 = tt.reshape %new_idxs_457 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc285) + %left_idx_470 = arith.muli %y_idx_469, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc286) + %left_idx_471 = "tt.reduce"(%left_idx_470) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_472 = tt.expand_dims %left_idx_471 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> 
tensor<64x1x1xi32, #blocked4> loc(#loc288) + %left_idx_473 = tt.broadcast %left_idx_472 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc289) + %right_idx_474 = arith.muli %y_idx_469, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc290) + %right_idx_475 = "tt.reduce"(%right_idx_474) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_476 = tt.expand_dims %right_idx_475 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc292) + %right_idx_477 = tt.broadcast %right_idx_476 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc293) + %left_idx_478 = tt.reshape %left_idx_473 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_479 = tt.reshape %right_idx_477 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_480 = arith.cmpi slt, %ileft_467, %iright_468 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_481 = arith.cmpi eq, %ileft_467, %iright_468 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_482 = arith.cmpi sgt, %left_idx_478, %right_idx_479 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_483 = arith.andi %eq_481, %cond_482 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_484 = arith.ori %cond_480, %cond_483 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_485 = arith.extui %cond_484 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_486 = arith.xori %cond_485, %flip_84 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_487 = arith.cmpi ne, %cond_486, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_488 = arith.xori 
%ileft_467, %iright_468 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_489 = arith.select %cond_487, %ret_488, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_490 = arith.xori %ret_454, %ret_489 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_491 = arith.xori %left_idx_478, %right_idx_479 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_492 = arith.select %cond_487, %new_idxs_491, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_493 = arith.xori %new_idxs_457, %new_idxs_492 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_494 = tt.reshape %ret_490 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc263) + %ileft_495 = arith.muli %y_494, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc264) + %ileft_496 = "tt.reduce"(%ileft_495) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc310) + %ileft_497 = tt.expand_dims %ileft_496 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc266) + %ileft_498 = tt.broadcast %ileft_497 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc267) + %iright_499 = arith.muli %y_494, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc268) + %iright_500 = "tt.reduce"(%iright_499) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc312) + %iright_501 = 
tt.expand_dims %iright_500 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc270) + %iright_502 = tt.broadcast %iright_501 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc271) + %ileft_503 = tt.reshape %ileft_498 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_504 = tt.reshape %iright_502 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_505 = tt.reshape %new_idxs_493 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc285) + %left_idx_506 = arith.muli %y_idx_505, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc286) + %left_idx_507 = "tt.reduce"(%left_idx_506) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc314) + %left_idx_508 = tt.expand_dims %left_idx_507 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc288) + %left_idx_509 = tt.broadcast %left_idx_508 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc289) + %right_idx_510 = arith.muli %y_idx_505, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc290) + %right_idx_511 = "tt.reduce"(%right_idx_510) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc316) + %right_idx_512 = tt.expand_dims 
%right_idx_511 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc292) + %right_idx_513 = tt.broadcast %right_idx_512 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc293) + %left_idx_514 = tt.reshape %left_idx_509 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_515 = tt.reshape %right_idx_513 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_516 = arith.cmpi slt, %ileft_503, %iright_504 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_517 = arith.cmpi eq, %ileft_503, %iright_504 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_518 = arith.cmpi sgt, %left_idx_514, %right_idx_515 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_519 = arith.andi %eq_517, %cond_518 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_520 = arith.ori %cond_516, %cond_519 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_521 = arith.extui %cond_520 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_522 = arith.xori %cond_521, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_523 = arith.cmpi ne, %cond_522, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_524 = arith.xori %ileft_503, %iright_504 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_525 = arith.select %cond_523, %ret_524, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_526 = arith.xori %ret_490, %ret_525 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_527 = arith.xori %left_idx_514, %right_idx_515 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_528 = arith.select %cond_523, %new_idxs_527, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_529 = arith.xori %new_idxs_493, %new_idxs_528 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_530 = tt.reshape %ret_526 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc263) + %ileft_531 
= arith.muli %y_530, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc264) + %ileft_532 = "tt.reduce"(%ileft_531) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + %ileft_533 = tt.expand_dims %ileft_532 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc266) + %ileft_534 = tt.broadcast %ileft_533 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc267) + %iright_535 = arith.muli %y_530, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc268) + %iright_536 = "tt.reduce"(%iright_535) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_537 = tt.expand_dims %iright_536 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc270) + %iright_538 = tt.broadcast %iright_537 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc271) + %ileft_539 = tt.reshape %ileft_534 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_540 = tt.reshape %iright_538 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_541 = tt.reshape %new_idxs_529 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc285) + %left_idx_542 = arith.muli %y_idx_541, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc286) + %left_idx_543 = 
"tt.reduce"(%left_idx_542) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_544 = tt.expand_dims %left_idx_543 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc288) + %left_idx_545 = tt.broadcast %left_idx_544 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc289) + %right_idx_546 = arith.muli %y_idx_541, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc290) + %right_idx_547 = "tt.reduce"(%right_idx_546) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_548 = tt.expand_dims %right_idx_547 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc292) + %right_idx_549 = tt.broadcast %right_idx_548 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc293) + %left_idx_550 = tt.reshape %left_idx_545 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_551 = tt.reshape %right_idx_549 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_552 = arith.cmpi slt, %ileft_539, %iright_540 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_553 = arith.cmpi eq, %ileft_539, %iright_540 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_554 = arith.cmpi sgt, %left_idx_550, 
%right_idx_551 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_555 = arith.andi %eq_553, %cond_554 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_556 = arith.ori %cond_552, %cond_555 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_557 = arith.extui %cond_556 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_558 = arith.xori %cond_557, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_559 = arith.cmpi ne, %cond_558, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_560 = arith.xori %ileft_539, %iright_540 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_561 = arith.select %cond_559, %ret_560, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_562 = arith.xori %ret_526, %ret_561 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_563 = arith.xori %left_idx_550, %right_idx_551 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_564 = arith.select %cond_559, %new_idxs_563, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_565 = arith.xori %new_idxs_529, %new_idxs_564 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_566 = tt.reshape %ret_562 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc263) + %ileft_567 = arith.muli %y_566, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc264) + %ileft_568 = "tt.reduce"(%ileft_567) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_569 = tt.expand_dims %ileft_568 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc266) + %ileft_570 = tt.broadcast %ileft_569 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, 
#blocked4> loc(#loc267) + %iright_571 = arith.muli %y_566, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc268) + %iright_572 = "tt.reduce"(%iright_571) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_573 = tt.expand_dims %iright_572 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc270) + %iright_574 = tt.broadcast %iright_573 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc271) + %ileft_575 = tt.reshape %ileft_570 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_576 = tt.reshape %iright_574 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_577 = tt.reshape %new_idxs_565 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc285) + %left_idx_578 = arith.muli %y_idx_577, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc286) + %left_idx_579 = "tt.reduce"(%left_idx_578) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_580 = tt.expand_dims %left_idx_579 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc288) + %left_idx_581 = tt.broadcast %left_idx_580 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc289) + %right_idx_582 = arith.muli %y_idx_577, 
%iright : tensor<64x2x1xi32, #blocked4> loc(#loc290) + %right_idx_583 = "tt.reduce"(%right_idx_582) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_584 = tt.expand_dims %right_idx_583 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc292) + %right_idx_585 = tt.broadcast %right_idx_584 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc293) + %left_idx_586 = tt.reshape %left_idx_581 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_587 = tt.reshape %right_idx_585 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_588 = arith.cmpi slt, %ileft_575, %iright_576 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_589 = arith.cmpi eq, %ileft_575, %iright_576 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_590 = arith.cmpi sgt, %left_idx_586, %right_idx_587 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_591 = arith.andi %eq_589, %cond_590 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_592 = arith.ori %cond_588, %cond_591 : tensor<8x16xi1, #blocked> loc(#loc277) + %cond_593 = arith.extui %cond_592 : tensor<8x16xi1, #blocked> to tensor<8x16xi32, #blocked> loc(#loc278) + %cond_594 = arith.xori %cond_593, %flip_159 : tensor<8x16xi32, #blocked> loc(#loc278) + %cond_595 = arith.cmpi ne, %cond_594, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc279) + %ret_596 = arith.xori %ileft_575, %iright_576 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_597 = arith.select %cond_595, %ret_596, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_598 = 
arith.xori %ret_562, %ret_597 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_599 = arith.xori %left_idx_586, %right_idx_587 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_600 = arith.select %cond_595, %new_idxs_599, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_601 = arith.xori %new_idxs_565, %new_idxs_600 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_602 = tt.reshape %ret_598 : tensor<8x16xi32, #blocked> -> tensor<8x2x8xi32, #blocked1> loc(#loc263) + %ileft_603 = arith.muli %y_602, %ileft_270 : tensor<8x2x8xi32, #blocked1> loc(#loc264) + %ileft_604 = "tt.reduce"(%ileft_603) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc310) + %ileft_605 = tt.expand_dims %ileft_604 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc266) + %ileft_606 = tt.broadcast %ileft_605 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc267) + %iright_607 = arith.muli %y_602, %flip_158 : tensor<8x2x8xi32, #blocked1> loc(#loc268) + %iright_608 = "tt.reduce"(%iright_607) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc312) + %iright_609 = tt.expand_dims %iright_608 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc270) + %iright_610 = tt.broadcast %iright_609 : tensor<8x1x8xi32, 
#blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc271) + %ileft_611 = tt.reshape %ileft_606 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_612 = tt.reshape %iright_610 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_613 = tt.reshape %new_idxs_601 : tensor<8x16xi32, #blocked> -> tensor<8x2x8xi32, #blocked1> loc(#loc285) + %left_idx_614 = arith.muli %y_idx_613, %ileft_270 : tensor<8x2x8xi32, #blocked1> loc(#loc286) + %left_idx_615 = "tt.reduce"(%left_idx_614) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc314) + %left_idx_616 = tt.expand_dims %left_idx_615 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc288) + %left_idx_617 = tt.broadcast %left_idx_616 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, #blocked1> loc(#loc289) + %right_idx_618 = arith.muli %y_idx_613, %flip_158 : tensor<8x2x8xi32, #blocked1> loc(#loc290) + %right_idx_619 = "tt.reduce"(%right_idx_618) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<8x2x8xi32, #blocked1>) -> tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc316) + %right_idx_620 = tt.expand_dims %right_idx_619 {axis = 1 : i32} : tensor<8x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1x8xi32, #blocked1> loc(#loc292) + %right_idx_621 = tt.broadcast %right_idx_620 : tensor<8x1x8xi32, #blocked1> -> tensor<8x2x8xi32, 
#blocked1> loc(#loc293) + %left_idx_622 = tt.reshape %left_idx_617 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_623 = tt.reshape %right_idx_621 : tensor<8x2x8xi32, #blocked1> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_624 = arith.cmpi slt, %ileft_611, %iright_612 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_625 = arith.cmpi eq, %ileft_611, %iright_612 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_626 = arith.cmpi sgt, %left_idx_622, %right_idx_623 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_627 = arith.andi %eq_625, %cond_626 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_628 = arith.ori %cond_624, %cond_627 : tensor<8x16xi1, #blocked> loc(#loc277) + %ret_629 = arith.xori %ileft_611, %iright_612 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_630 = arith.select %cond_628, %ret_629, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_631 = arith.xori %ret_598, %ret_630 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_632 = arith.xori %left_idx_622, %right_idx_623 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_633 = arith.select %cond_628, %new_idxs_632, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_634 = arith.xori %new_idxs_601, %new_idxs_633 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_635 = tt.reshape %ret_631 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc263) + %ileft_636 = arith.muli %y_635, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc264) + %ileft_637 = "tt.reduce"(%ileft_636) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc310) + %ileft_638 = tt.expand_dims %ileft_637 {axis = 1 
: i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc266) + %ileft_639 = tt.broadcast %ileft_638 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc267) + %iright_640 = arith.muli %y_635, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc268) + %iright_641 = "tt.reduce"(%iright_640) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc312) + %iright_642 = tt.expand_dims %iright_641 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc270) + %iright_643 = tt.broadcast %iright_642 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc271) + %ileft_644 = tt.reshape %ileft_639 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_645 = tt.reshape %iright_643 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_646 = tt.reshape %new_idxs_634 : tensor<8x16xi32, #blocked> -> tensor<16x2x4xi32, #blocked2> loc(#loc285) + %left_idx_647 = arith.muli %y_idx_646, %ileft_161 : tensor<16x2x4xi32, #blocked2> loc(#loc286) + %left_idx_648 = "tt.reduce"(%left_idx_647) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc314) + %left_idx_649 = tt.expand_dims %left_idx_648 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = 
#blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc288) + %left_idx_650 = tt.broadcast %left_idx_649 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc289) + %right_idx_651 = arith.muli %y_idx_646, %flip_83 : tensor<16x2x4xi32, #blocked2> loc(#loc290) + %right_idx_652 = "tt.reduce"(%right_idx_651) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<16x2x4xi32, #blocked2>) -> tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc316) + %right_idx_653 = tt.expand_dims %right_idx_652 {axis = 1 : i32} : tensor<16x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1x4xi32, #blocked2> loc(#loc292) + %right_idx_654 = tt.broadcast %right_idx_653 : tensor<16x1x4xi32, #blocked2> -> tensor<16x2x4xi32, #blocked2> loc(#loc293) + %left_idx_655 = tt.reshape %left_idx_650 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_656 = tt.reshape %right_idx_654 : tensor<16x2x4xi32, #blocked2> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_657 = arith.cmpi slt, %ileft_644, %iright_645 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_658 = arith.cmpi eq, %ileft_644, %iright_645 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_659 = arith.cmpi sgt, %left_idx_655, %right_idx_656 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_660 = arith.andi %eq_658, %cond_659 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_661 = arith.ori %cond_657, %cond_660 : tensor<8x16xi1, #blocked> loc(#loc277) + %ret_662 = arith.xori %ileft_644, %iright_645 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_663 = arith.select %cond_661, %ret_662, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_664 = arith.xori %ret_631, %ret_663 : tensor<8x16xi32, #blocked> 
loc(#loc282) + %new_idxs_665 = arith.xori %left_idx_655, %right_idx_656 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_666 = arith.select %cond_661, %new_idxs_665, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_667 = arith.xori %new_idxs_634, %new_idxs_666 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_668 = tt.reshape %ret_664 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc263) + %ileft_669 = arith.muli %y_668, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc264) + %ileft_670 = "tt.reduce"(%ileft_669) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + %ileft_671 = tt.expand_dims %ileft_670 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc266) + %ileft_672 = tt.broadcast %ileft_671 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc267) + %iright_673 = arith.muli %y_668, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc268) + %iright_674 = "tt.reduce"(%iright_673) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_675 = tt.expand_dims %iright_674 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc270) + %iright_676 = tt.broadcast %iright_675 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> 
loc(#loc271) + %ileft_677 = tt.reshape %ileft_672 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_678 = tt.reshape %iright_676 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_679 = tt.reshape %new_idxs_667 : tensor<8x16xi32, #blocked> -> tensor<32x2x2xi32, #blocked3> loc(#loc285) + %left_idx_680 = arith.muli %y_idx_679, %ileft_86 : tensor<32x2x2xi32, #blocked3> loc(#loc286) + %left_idx_681 = "tt.reduce"(%left_idx_680) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_682 = tt.expand_dims %left_idx_681 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc288) + %left_idx_683 = tt.broadcast %left_idx_682 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc289) + %right_idx_684 = arith.muli %y_idx_679, %flip_43 : tensor<32x2x2xi32, #blocked3> loc(#loc290) + %right_idx_685 = "tt.reduce"(%right_idx_684) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<32x2x2xi32, #blocked3>) -> tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_686 = tt.expand_dims %right_idx_685 {axis = 1 : i32} : tensor<32x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x2xi32, #blocked3> loc(#loc292) + %right_idx_687 = tt.broadcast %right_idx_686 : tensor<32x1x2xi32, #blocked3> -> tensor<32x2x2xi32, #blocked3> loc(#loc293) + 
%left_idx_688 = tt.reshape %left_idx_683 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_689 = tt.reshape %right_idx_687 : tensor<32x2x2xi32, #blocked3> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_690 = arith.cmpi slt, %ileft_677, %iright_678 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_691 = arith.cmpi eq, %ileft_677, %iright_678 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_692 = arith.cmpi sgt, %left_idx_688, %right_idx_689 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_693 = arith.andi %eq_691, %cond_692 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_694 = arith.ori %cond_690, %cond_693 : tensor<8x16xi1, #blocked> loc(#loc277) + %ret_695 = arith.xori %ileft_677, %iright_678 : tensor<8x16xi32, #blocked> loc(#loc280) + %ret_696 = arith.select %cond_694, %ret_695, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc281) + %ret_697 = arith.xori %ret_664, %ret_696 : tensor<8x16xi32, #blocked> loc(#loc282) + %new_idxs_698 = arith.xori %left_idx_688, %right_idx_689 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_699 = arith.select %cond_694, %new_idxs_698, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_700 = arith.xori %new_idxs_667, %new_idxs_699 : tensor<8x16xi32, #blocked> loc(#loc284) + %y_701 = tt.reshape %ret_697 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc263) + %ileft_702 = arith.muli %y_701, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc264) + %ileft_703 = "tt.reduce"(%ileft_702) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_704 = tt.expand_dims %ileft_703 {axis = 1 : i32} : tensor<64x1xi32, 
#ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc266) + %ileft_705 = tt.broadcast %ileft_704 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc267) + %iright_706 = arith.muli %y_701, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc268) + %iright_707 = "tt.reduce"(%iright_706) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_708 = tt.expand_dims %iright_707 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc270) + %iright_709 = tt.broadcast %iright_708 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc271) + %ileft_710 = tt.reshape %ileft_705 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc272) + %iright_711 = tt.reshape %iright_709 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc273) + %y_idx_712 = tt.reshape %new_idxs_700 : tensor<8x16xi32, #blocked> -> tensor<64x2x1xi32, #blocked4> loc(#loc285) + %left_idx_713 = arith.muli %y_idx_712, %ileft : tensor<64x2x1xi32, #blocked4> loc(#loc286) + %left_idx_714 = "tt.reduce"(%left_idx_713) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_715 = tt.expand_dims %left_idx_714 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, 
#blocked4> loc(#loc288) + %left_idx_716 = tt.broadcast %left_idx_715 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc289) + %right_idx_717 = arith.muli %y_idx_712, %iright : tensor<64x2x1xi32, #blocked4> loc(#loc290) + %right_idx_718 = "tt.reduce"(%right_idx_717) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<64x2x1xi32, #blocked4>) -> tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_719 = tt.expand_dims %right_idx_718 {axis = 1 : i32} : tensor<64x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<64x1x1xi32, #blocked4> loc(#loc292) + %right_idx_720 = tt.broadcast %right_idx_719 : tensor<64x1x1xi32, #blocked4> -> tensor<64x2x1xi32, #blocked4> loc(#loc293) + %left_idx_721 = tt.reshape %left_idx_716 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc294) + %right_idx_722 = tt.reshape %right_idx_720 : tensor<64x2x1xi32, #blocked4> -> tensor<8x16xi32, #blocked> loc(#loc295) + %cond_723 = arith.cmpi slt, %ileft_710, %iright_711 : tensor<8x16xi32, #blocked> loc(#loc274) + %eq_724 = arith.cmpi eq, %ileft_710, %iright_711 : tensor<8x16xi32, #blocked> loc(#loc275) + %cond_725 = arith.cmpi sgt, %left_idx_721, %right_idx_722 : tensor<8x16xi32, #blocked> loc(#loc296) + %cond_726 = arith.andi %eq_724, %cond_725 : tensor<8x16xi1, #blocked> loc(#loc276) + %cond_727 = arith.ori %cond_723, %cond_726 : tensor<8x16xi1, #blocked> loc(#loc277) + %new_idxs_728 = arith.xori %left_idx_721, %right_idx_722 : tensor<8x16xi32, #blocked> loc(#loc297) + %new_idxs_729 = arith.select %cond_727, %new_idxs_728, %cst_9 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc283) + %new_idxs_730 = arith.xori %new_idxs_700, %new_idxs_729 : tensor<8x16xi32, #blocked> 
loc(#loc284) + %tmp20 = arith.extui %tmp5 : tensor<8x16xi1, #blocked> to tensor<8x16xi64, #blocked> loc(#loc219) + %tmp23 = arith.select %tmp0_29, %tmp20, %cst_13 : tensor<8x16xi1, #blocked>, tensor<8x16xi64, #blocked> loc(#loc191) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_741: i64 loc(callsite(#loc1 at #loc192)), %tmp24_742: i64 loc(callsite(#loc1 at #loc192))): + %tmp24_743 = arith.addi %tmp24_741, %tmp24_742 : i64 loc(#loc298) + tt.reduce.return %tmp24_743 : i64 loc(#loc220) + }) : (tensor<8x16xi64, #blocked>) -> tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc220) + %tmp30 = ttg.convert_layout %tmp24 : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc193) + %tmp24_731 = tt.expand_dims %tmp30 {axis = 1 : i32} : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1xi64, #blocked5> loc(#loc194) + %tmp24_732 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi64, #blocked> loc(#loc194) + %tmp25 = arith.extui %tmp14 : tensor<8x16xi1, #blocked> to tensor<8x16xi64, #blocked> loc(#loc222) + %tmp28 = arith.select %tmp0_29, %tmp25, %cst_13 : tensor<8x16xi1, #blocked>, tensor<8x16xi64, #blocked> loc(#loc196) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_741: i64 loc(callsite(#loc1 at #loc197)), %tmp29_742: i64 loc(callsite(#loc1 at #loc197))): + %tmp29_743 = arith.addi %tmp29_741, %tmp29_742 : i64 loc(#loc299) + tt.reduce.return %tmp29_743 : i64 loc(#loc223) + }) : (tensor<8x16xi64, #blocked>) -> tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc223) + %tmp31 = ttg.convert_layout %tmp29 : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc198) + %tmp29_733 = tt.expand_dims %tmp31 {axis = 1 : i32} : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1xi64, 
#blocked5> loc(#loc199) + %tmp29_734 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi64, #blocked> loc(#loc199) + %tmp30_735 = arith.trunci %tmp24_731 : tensor<8x1xi64, #blocked5> to tensor<8x1xi32, #blocked5> loc(#loc193) + %tmp30_736 = arith.trunci %tmp24_732 : tensor<8x1xi64, #blocked> to tensor<8x1xi32, #blocked> loc(#loc193) + %tmp31_737 = arith.trunci %tmp29_733 : tensor<8x1xi64, #blocked5> to tensor<8x1xi32, #blocked5> loc(#loc198) + %tmp31_738 = arith.trunci %tmp29_734 : tensor<8x1xi64, #blocked> to tensor<8x1xi32, #blocked> loc(#loc198) + %tmp34 = tt.broadcast %tmp30_736 : tensor<8x1xi32, #blocked> -> tensor<8x16xi32, #blocked> loc(#loc200) + %tmp34_739 = arith.cmpi slt, %tmp0_24, %tmp34 : tensor<8x16xi32, #blocked> loc(#loc200) + %tmp36 = arith.select %tmp34_739, %new_idxs_398, %cst_11 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc201) + %tmp38 = arith.addi %tmp36, %cst_10 : tensor<8x16xi32, #blocked> loc(#loc202) + %tmp39 = arith.cmpi slt, %tmp36, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc203) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc204) + %0 = arith.cmpi sge, %tmp40, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc85) + %1 = arith.cmpi slt, %tmp40, %cst_10 : tensor<8x16xi32, #blocked> loc(#loc86) + %2 = arith.andi %0, %1 : tensor<8x16xi1, #blocked> loc(#loc87) + %3 = arith.xori %xmask, %cst : tensor<8x1xi1, #blocked> loc(#loc88) + %4 = tt.broadcast %3 : tensor<8x1xi1, #blocked> -> tensor<8x16xi1, #blocked> loc(#loc89) + %5 = arith.ori %2, %4 : tensor<8x16xi1, #blocked> loc(#loc89) + tt.assert %5, "index out of bounds: 0 <= tmp40 < 17" : tensor<8x16xi1, #blocked> loc(#loc90) + %tmp45 = tt.broadcast %tmp31_738 : tensor<8x1xi32, #blocked> -> tensor<8x16xi32, #blocked> loc(#loc205) + %tmp45_740 = arith.cmpi slt, %tmp0_24, %tmp45 : tensor<8x16xi32, #blocked> loc(#loc205) + %tmp46 = arith.select %tmp45_740, 
%new_idxs_730, %cst_11 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc206) + %tmp47 = arith.addi %tmp46, %cst_10 : tensor<8x16xi32, #blocked> loc(#loc207) + %tmp48 = arith.cmpi slt, %tmp46, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc208) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<8x16xi1, #blocked>, tensor<8x16xi32, #blocked> loc(#loc209) + %6 = arith.cmpi sge, %tmp49, %cst_9 : tensor<8x16xi32, #blocked> loc(#loc96) + %7 = arith.cmpi slt, %tmp49, %cst_10 : tensor<8x16xi32, #blocked> loc(#loc97) + %8 = arith.andi %6, %7 : tensor<8x16xi1, #blocked> loc(#loc98) + %9 = arith.ori %8, %4 : tensor<8x16xi1, #blocked> loc(#loc99) + tt.assert %9, "index out of bounds: 0 <= tmp49 < 17" : tensor<8x16xi1, #blocked> loc(#loc100) + %10 = tt.splat %out_ptr4 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked5> loc(#loc101) + %11 = tt.addptr %10, %xindex_21 : tensor<8x1x!tt.ptr, #blocked5>, tensor<8x1xi32, #blocked5> loc(#loc101) + tt.store %11, %tmp30_735, %xmask_22 : tensor<8x1x!tt.ptr, #blocked5> loc(#loc102) + %12 = tt.splat %out_ptr5 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked5> loc(#loc103) + %13 = tt.addptr %12, %xindex_21 : tensor<8x1x!tt.ptr, #blocked5>, tensor<8x1xi32, #blocked5> loc(#loc103) + tt.store %13, %tmp31_737, %xmask_22 : tensor<8x1x!tt.ptr, #blocked5> loc(#loc104) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc105) + %15 = tt.addptr %14, %tmp0_26 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi32, #blocked> loc(#loc105) + tt.store %15, %new_idxs_398, %tmp0_29 : tensor<8x16x!tt.ptr, #blocked> loc(#loc106) + %16 = arith.muli %xindex_20, %cst_0 : tensor<8x1xi32, #blocked> loc(#loc107) + %17 = tt.broadcast %16 : tensor<8x1xi32, #blocked> -> tensor<8x16xi32, #blocked> loc(#loc108) + %18 = arith.addi %tmp40, %17 : tensor<8x16xi32, #blocked> loc(#loc108) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc109) + %20 = tt.addptr %19, %18 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi32, 
#blocked> loc(#loc109) + %21 = ttg.convert_layout %20 : tensor<8x16x!tt.ptr, #blocked> -> tensor<8x16x!tt.ptr, #blocked5> loc(#loc110) + tt.store %21, %cst_8, %tmp0_30 : tensor<8x16x!tt.ptr, #blocked5> loc(#loc110) + %22 = tt.splat %out_ptr8 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc111) + %23 = tt.addptr %22, %tmp0_26 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi32, #blocked> loc(#loc111) + tt.store %23, %new_idxs_730, %tmp0_29 : tensor<8x16x!tt.ptr, #blocked> loc(#loc112) + %24 = arith.addi %tmp49, %17 : tensor<8x16xi32, #blocked> loc(#loc113) + %25 = tt.splat %out_ptr9 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc114) + %26 = tt.addptr %25, %24 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi32, #blocked> loc(#loc114) + %27 = ttg.convert_layout %26 : tensor<8x16x!tt.ptr, #blocked> -> tensor<8x16x!tt.ptr, #blocked5> loc(#loc115) + tt.store %27, %cst_8, %tmp0_30 : tensor<8x16x!tt.ptr, #blocked5> loc(#loc115) + tt.return loc(#loc116) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc9 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":45:34) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc27 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc47 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc65 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc82 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc96 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc110 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc126 = loc("xoffset"(#loc2)) +#loc127 = loc("xoffset"(#loc3)) +#loc128 = loc("xindex"(#loc4)) +#loc129 = loc("xindex"(#loc5)) +#loc130 = loc("xmask"(#loc6)) +#loc131 = loc("r0_index"(#loc7)) +#loc132 = loc("tmp0"(#loc8)) +#loc133 = loc("tmp0"(#loc9)) +#loc134 = loc("tmp0"(#loc10)) +#loc135 = loc("tmp0"(#loc11)) +#loc136 = loc("tmp2"(#loc12)) +#loc137 = loc("tmp4"(#loc13)) +#loc138 = loc("tmp5"(#loc14)) +#loc139 = loc("tmp7"(#loc15)) +#loc140 = loc("tmp6"(#loc16)) +#loc141 = loc("tmp9"(#loc17)) +#loc142 = loc("tmp11"(#loc18)) +#loc143 = loc("flip"(#loc19)) +#loc145 = loc("flip"(#loc22)) +#loc146 = loc("flip"(#loc23)) +#loc147 = loc("y"(#loc24)) +#loc148 = loc("left_mask"(#loc26)) +#loc149 = loc("ileft"(#loc27)) +#loc151 = loc("ileft"(#loc31)) +#loc152 = loc("ileft"(#loc32)) +#loc153 = loc("iright"(#loc33)) +#loc155 = loc("iright"(#loc35)) +#loc156 = loc("iright"(#loc36)) +#loc157 = loc("ileft"(#loc37)) +#loc158 = loc("iright"(#loc38)) +#loc159 = loc("y_idx"(#loc39)) +#loc160 = loc("left_idx"(#loc40)) +#loc161 = 
loc("left_idx"(#loc41)) +#loc162 = loc("input"(#loc42)) +#loc164 = loc("left_idx"(#loc44)) +#loc165 = loc("left_idx"(#loc45)) +#loc166 = loc("right_idx"(#loc46)) +#loc167 = loc("right_idx"(#loc47)) +#loc169 = loc("right_idx"(#loc49)) +#loc170 = loc("right_idx"(#loc50)) +#loc171 = loc("left_idx"(#loc51)) +#loc172 = loc("right_idx"(#loc52)) +#loc173 = loc("cond"(#loc53)) +#loc174 = loc("eq"(#loc54)) +#loc175 = loc("cond"(#loc55)) +#loc176 = loc("cond"(#loc56)) +#loc177 = loc("cond"(#loc57)) +#loc178 = loc("cond"(#loc58)) +#loc179 = loc("cond"(#loc59)) +#loc180 = loc("ret"(#loc60)) +#loc181 = loc("ret"(#loc61)) +#loc182 = loc("ret"(#loc62)) +#loc183 = loc("new_idxs"(#loc63)) +#loc184 = loc("new_idxs"(#loc64)) +#loc185 = loc("new_idxs"(#loc65)) +#loc186 = loc("tmp14"(#loc66)) +#loc187 = loc("tmp16"(#loc67)) +#loc188 = loc("tmp15"(#loc68)) +#loc190 = loc("tmp20"(#loc70)) +#loc191 = loc("tmp23"(#loc71)) +#loc193 = loc("tmp30"(#loc73)) +#loc194 = loc("tmp24"(#loc74)) +#loc195 = loc("tmp25"(#loc75)) +#loc196 = loc("tmp28"(#loc76)) +#loc198 = loc("tmp31"(#loc78)) +#loc199 = loc("tmp29"(#loc79)) +#loc200 = loc("tmp34"(#loc80)) +#loc201 = loc("tmp36"(#loc81)) +#loc202 = loc("tmp38"(#loc82)) +#loc203 = loc("tmp39"(#loc83)) +#loc204 = loc("tmp40"(#loc84)) +#loc205 = loc("tmp45"(#loc91)) +#loc206 = loc("tmp46"(#loc92)) +#loc207 = loc("tmp47"(#loc93)) +#loc208 = loc("tmp48"(#loc94)) +#loc209 = loc("tmp49"(#loc95)) +#loc210 = loc(fused[#loc139, #loc140]) +#loc211 = loc(callsite(#loc143 at #loc144)) +#loc212 = loc(callsite(#loc145 at #loc144)) +#loc213 = loc(callsite(#loc146 at #loc144)) +#loc215 = loc("cond"(#loc173)) +#loc216 = loc("eq"(#loc174)) +#loc217 = loc(fused[#loc187, #loc188]) +#loc219 = loc(fused[#loc190, #loc139, #loc140]) +#loc220 = loc(callsite(#loc28 at #loc192)) +#loc222 = loc(fused[#loc195, #loc187, #loc188]) +#loc223 = loc(callsite(#loc28 at #loc197)) +#loc225 = loc(callsite(#loc147 at #loc214)) +#loc226 = loc(callsite(#loc148 at #loc214)) +#loc227 = 
loc(callsite(#loc149 at #loc214)) +#loc229 = loc(callsite(#loc151 at #loc214)) +#loc230 = loc(callsite(#loc152 at #loc214)) +#loc231 = loc(callsite(#loc153 at #loc214)) +#loc233 = loc(callsite(#loc155 at #loc214)) +#loc234 = loc(callsite(#loc156 at #loc214)) +#loc235 = loc(callsite(#loc157 at #loc214)) +#loc236 = loc(callsite(#loc158 at #loc214)) +#loc237 = loc(callsite(#loc159 at #loc214)) +#loc238 = loc(callsite(#loc160 at #loc214)) +#loc239 = loc(callsite(#loc161 at #loc214)) +#loc241 = loc(callsite(#loc164 at #loc214)) +#loc242 = loc(callsite(#loc165 at #loc214)) +#loc243 = loc(callsite(#loc166 at #loc214)) +#loc244 = loc(callsite(#loc167 at #loc214)) +#loc246 = loc(callsite(#loc169 at #loc214)) +#loc247 = loc(callsite(#loc170 at #loc214)) +#loc248 = loc(callsite(#loc171 at #loc214)) +#loc249 = loc(callsite(#loc172 at #loc214)) +#loc250 = loc(callsite(#loc215 at #loc214)) +#loc251 = loc(callsite(#loc216 at #loc214)) +#loc252 = loc(callsite(#loc175 at #loc214)) +#loc253 = loc(callsite(#loc176 at #loc214)) +#loc254 = loc(callsite(#loc177 at #loc214)) +#loc255 = loc(callsite(#loc178 at #loc214)) +#loc256 = loc(callsite(#loc179 at #loc214)) +#loc257 = loc(callsite(#loc180 at #loc214)) +#loc258 = loc(callsite(#loc181 at #loc214)) +#loc259 = loc(callsite(#loc182 at #loc214)) +#loc260 = loc(callsite(#loc183 at #loc214)) +#loc261 = loc(callsite(#loc184 at #loc214)) +#loc262 = loc(callsite(#loc185 at #loc214)) +#loc263 = loc(callsite(#loc147 at #loc218)) +#loc264 = loc(callsite(#loc149 at #loc218)) +#loc266 = loc(callsite(#loc151 at #loc218)) +#loc267 = loc(callsite(#loc152 at #loc218)) +#loc268 = loc(callsite(#loc153 at #loc218)) +#loc270 = loc(callsite(#loc155 at #loc218)) +#loc271 = loc(callsite(#loc156 at #loc218)) +#loc272 = loc(callsite(#loc157 at #loc218)) +#loc273 = loc(callsite(#loc158 at #loc218)) +#loc274 = loc(callsite(#loc215 at #loc218)) +#loc275 = loc(callsite(#loc216 at #loc218)) +#loc276 = loc(callsite(#loc176 at #loc218)) +#loc277 = 
loc(callsite(#loc177 at #loc218)) +#loc278 = loc(callsite(#loc178 at #loc218)) +#loc279 = loc(callsite(#loc179 at #loc218)) +#loc280 = loc(callsite(#loc180 at #loc218)) +#loc281 = loc(callsite(#loc181 at #loc218)) +#loc282 = loc(callsite(#loc182 at #loc218)) +#loc283 = loc(callsite(#loc184 at #loc218)) +#loc284 = loc(callsite(#loc185 at #loc218)) +#loc285 = loc(callsite(#loc159 at #loc218)) +#loc286 = loc(callsite(#loc161 at #loc218)) +#loc288 = loc(callsite(#loc164 at #loc218)) +#loc289 = loc(callsite(#loc165 at #loc218)) +#loc290 = loc(callsite(#loc167 at #loc218)) +#loc292 = loc(callsite(#loc169 at #loc218)) +#loc293 = loc(callsite(#loc170 at #loc218)) +#loc294 = loc(callsite(#loc171 at #loc218)) +#loc295 = loc(callsite(#loc172 at #loc218)) +#loc296 = loc(callsite(#loc175 at #loc218)) +#loc297 = loc(callsite(#loc183 at #loc218)) +#loc298 = loc(callsite(#loc30 at #loc220)) +#loc299 = loc(callsite(#loc30 at #loc223)) +#loc300 = loc(callsite(#loc28 at #loc228)) +#loc302 = loc(callsite(#loc28 at #loc232)) +#loc304 = loc(callsite(#loc162 at #loc240)) +#loc305 = loc(callsite(#loc28 at #loc240)) +#loc307 = loc(callsite(#loc162 at #loc245)) +#loc308 = loc(callsite(#loc28 at #loc245)) +#loc310 = loc(callsite(#loc28 at #loc265)) +#loc312 = loc(callsite(#loc28 at #loc269)) +#loc314 = loc(callsite(#loc28 at #loc287)) +#loc316 = loc(callsite(#loc28 at #loc291)) +#loc318 = loc(callsite(#loc30 at #loc300)) +#loc319 = loc(callsite(#loc30 at #loc302)) +#loc320 = loc(callsite(#loc30 at #loc305)) +#loc321 = loc(callsite(#loc30 at #loc308)) +#loc322 = loc(callsite(#loc30 at #loc310)) +#loc323 = loc(callsite(#loc30 at #loc312)) +#loc324 = loc(callsite(#loc30 at #loc314)) +#loc325 = loc(callsite(#loc30 at #loc316)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir 
b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir new file mode 100644 index 0000000000000000000000000000000000000000..521ed34ac168f535ae559ec4017a7133b1b4d062 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/5P66Y2R4BAVAKI2AZ4OOKOSCRGKH7SKT5SYLTP4RXGBUCBAJIDZQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir @@ -0,0 +1,1451 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc1 = loc(unknown) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) 
+#loc120 = loc("in_ptr0"(#loc)) +#loc121 = loc("out_ptr4"(#loc)) +#loc122 = loc("out_ptr5"(#loc)) +#loc123 = loc("out_ptr6"(#loc)) +#loc124 = loc("out_ptr7"(#loc)) +#loc125 = loc("out_ptr8"(#loc)) +#loc126 = loc("out_ptr9"(#loc)) +#loc127 = loc("xnumel"(#loc)) +#loc128 = loc("r0_numel"(#loc)) +#loc149 = loc(callsite(#loc22 at #loc23)) +#loc156 = loc("ileft"(#loc32)) +#loc160 = loc("iright"(#loc37)) +#loc169 = loc("left_idx"(#loc46)) +#loc174 = loc("right_idx"(#loc51)) +#loc195 = loc(callsite(#loc22 at #loc72)) +#loc198 = loc("tmp24"(#loc75)) +#loc202 = loc("tmp29"(#loc79)) +#loc221 = loc(callsite(#loc28 at #loc149)) +#loc225 = loc(callsite(#loc28 at #loc195)) +#loc228 = loc(callsite(#loc1 at #loc198)) +#loc231 = loc(callsite(#loc1 at #loc202)) +#loc235 = loc(callsite(#loc156 at #loc221)) +#loc239 = loc(callsite(#loc160 at #loc221)) +#loc247 = loc(callsite(#loc169 at #loc221)) +#loc252 = loc(callsite(#loc174 at #loc221)) +#loc272 = loc(callsite(#loc156 at #loc225)) +#loc276 = loc(callsite(#loc160 at #loc225)) +#loc294 = loc(callsite(#loc169 at #loc225)) +#loc298 = loc(callsite(#loc174 at #loc225)) +#loc308 = loc(callsite(#loc1 at #loc235)) +#loc310 = loc(callsite(#loc1 at #loc239)) +#loc313 = loc(callsite(#loc1 at #loc247)) +#loc316 = loc(callsite(#loc1 at #loc252)) +#loc318 = loc(callsite(#loc1 at #loc272)) +#loc320 = loc(callsite(#loc1 at #loc276)) +#loc322 = loc(callsite(#loc1 at #loc294)) +#loc324 = loc(callsite(#loc1 at #loc298)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 
: i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<8x16xi32> loc(#loc1) + %cst_1 = arith.constant dense<17> : tensor<8x1xi32> loc(#loc1) + %cst_2 = arith.constant dense : tensor<8x1xi1> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<8x16xi32> loc(#loc1) + %cst_4 = arith.constant dense<17> : tensor<8x16xi32> loc(#loc1) + %cst_5 = arith.constant dense<16> : tensor<8x16xi32> loc(#loc1) + %cst_6 = arith.constant dense<16384> : tensor<8x16xi64> loc(#loc1) + %cst_7 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc1) + %cst_8 = arith.constant dense<16> : tensor<8x1xi32> loc(#loc1) + %xmask = arith.constant dense<32> : tensor<8x1xi32> loc(#loc129) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc130) + %xoffset_9 = arith.muli %xoffset, %c8_i32 : i32 loc(#loc131) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc132) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc133) + %xindex_11 = tt.splat %xoffset_9 : i32 -> tensor<8x1xi32> loc(#loc134) + %xindex_12 = arith.addi %xindex_11, %xindex_10 : tensor<8x1xi32> loc(#loc134) + %xmask_13 = arith.cmpi slt, %xindex_12, %xmask : tensor<8x1xi32> loc(#loc129) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc135) + %r0_index_14 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc136) + %tmp0 = arith.muli %xindex_12, %cst_8 : tensor<8x1xi32> loc(#loc137) + %tmp0_15 = tt.broadcast %r0_index_14 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc138) + %tmp0_16 = tt.broadcast %tmp0 : tensor<8x1xi32> -> tensor<8x16xi32> 
loc(#loc138) + %tmp0_17 = arith.addi %tmp0_15, %tmp0_16 : tensor<8x16xi32> loc(#loc138) + %tmp0_18 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc139) + %tmp0_19 = tt.addptr %tmp0_18, %tmp0_17 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc139) + %tmp0_20 = tt.broadcast %xmask_13 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc140) + %tmp0_21 = tt.load %tmp0_19, %tmp0_20, %cst_7 : tensor<8x16x!tt.ptr> loc(#loc140) + %tmp2 = arith.cmpi sgt, %tmp0_21, %cst_7 : tensor<8x16xi64> loc(#loc141) + %tmp4 = arith.cmpi slt, %tmp0_21, %cst_6 : tensor<8x16xi64> loc(#loc142) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<8x16xi1> loc(#loc143) + %tmp7 = arith.extui %tmp5 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc216) + %tmp9 = arith.trunci %r0_index_14 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc146) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16> -> tensor<8x16xi16> loc(#loc147) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc217) + %flip_22 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc218) + %flip_23 = tt.expand_dims %flip_22 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc218) + %flip_24 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc219) + %flip_25 = tt.reshape %flip_24 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc220) + %y = tt.reshape %tmp7 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc232) + %left_mask = arith.subi %cst, %flip_23 : tensor<1x2x1xi32> loc(#loc233) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc234) + %ileft_26 = arith.muli %y, %ileft : tensor<64x2x1xi32> loc(#loc234) + %ileft_27 = "tt.reduce"(%ileft_26) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x1xi32>) -> 
tensor<64x1xi32> loc(#loc307) + %ileft_28 = tt.expand_dims %ileft_27 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc236) + %ileft_29 = tt.broadcast %ileft_28 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc237) + %iright = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<64x2x1xi32> loc(#loc238) + %iright_30 = arith.muli %y, %iright : tensor<64x2x1xi32> loc(#loc238) + %iright_31 = "tt.reduce"(%iright_30) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc309) + %iright_32 = tt.expand_dims %iright_31 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc240) + %iright_33 = tt.broadcast %iright_32 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc241) + %ileft_34 = tt.reshape %ileft_29 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_35 = tt.reshape %iright_33 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx = tt.reshape %tmp11 : tensor<8x16xi16> -> tensor<64x2x1xi16> loc(#loc244) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc245) + %left_idx_36 = tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<64x2x1xi16> loc(#loc246) + %left_idx_37 = arith.muli %y_idx, %left_idx_36 : tensor<64x2x1xi16> loc(#loc246) + %input = arith.extsi %left_idx_37 : tensor<64x2x1xi16> to tensor<64x2x1xi32> loc(#loc311) + %left_idx_38 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc312) + %left_idx_39 = tt.expand_dims %left_idx_38 {axis = 1 : i32} 
: tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc248) + %left_idx_40 = tt.broadcast %left_idx_39 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc249) + %right_idx = arith.trunci %flip_23 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc250) + %right_idx_41 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<64x2x1xi16> loc(#loc251) + %right_idx_42 = arith.muli %y_idx, %right_idx_41 : tensor<64x2x1xi16> loc(#loc251) + %input_43 = arith.extsi %right_idx_42 : tensor<64x2x1xi16> to tensor<64x2x1xi32> loc(#loc314) + %right_idx_44 = "tt.reduce"(%input_43) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc315) + %right_idx_45 = tt.expand_dims %right_idx_44 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc253) + %right_idx_46 = tt.broadcast %right_idx_45 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc254) + %left_idx_47 = tt.reshape %left_idx_40 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_48 = tt.reshape %right_idx_46 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc256) + %cond = arith.cmpi slt, %ileft_34, %iright_35 : tensor<8x16xi32> loc(#loc257) + %eq = arith.cmpi eq, %ileft_34, %iright_35 : tensor<8x16xi32> loc(#loc258) + %cond_49 = arith.cmpi sgt, %left_idx_47, %right_idx_48 : tensor<8x16xi32> loc(#loc259) + %cond_50 = arith.andi %eq, %cond_49 : tensor<8x16xi1> loc(#loc260) + %cond_51 = arith.ori %cond, %cond_50 : tensor<8x16xi1> loc(#loc261) + %cond_52 = arith.extui %cond_51 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) + %cond_53 = arith.xori %cond_52, %flip_25 : tensor<8x16xi32> loc(#loc262) + %cond_54 = arith.cmpi ne, %cond_53, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret = arith.xori %ileft_34, %iright_35 : tensor<8x16xi32> loc(#loc264) 
+ %ret_55 = arith.select %cond_54, %ret, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_56 = arith.xori %tmp7, %ret_55 : tensor<8x16xi32> loc(#loc266) + %new_idxs = arith.xori %left_idx_47, %right_idx_48 : tensor<8x16xi32> loc(#loc267) + %new_idxs_57 = arith.select %cond_54, %new_idxs, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_58 = arith.extsi %tmp9 : tensor<1x16xi16> to tensor<1x16xi32> loc(#loc269) + %new_idxs_59 = tt.broadcast %new_idxs_58 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc269) + %new_idxs_60 = arith.xori %new_idxs_59, %new_idxs_57 : tensor<8x16xi32> loc(#loc269) + %flip_61 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc219) + %flip_62 = tt.reshape %flip_61 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc220) + %y_63 = tt.reshape %ret_56 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc232) + %ileft_64 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<32x2x2xi32> loc(#loc234) + %ileft_65 = arith.muli %y_63, %ileft_64 : tensor<32x2x2xi32> loc(#loc234) + %ileft_66 = "tt.reduce"(%ileft_65) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc307) + %ileft_67 = tt.expand_dims %ileft_66 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc236) + %ileft_68 = tt.broadcast %ileft_67 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc237) + %iright_69 = arith.muli %y_63, %flip_24 : tensor<32x2x2xi32> loc(#loc238) + %iright_70 = "tt.reduce"(%iright_69) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<32x2x2xi32>) -> 
tensor<32x2xi32> loc(#loc309) + %iright_71 = tt.expand_dims %iright_70 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc240) + %iright_72 = tt.broadcast %iright_71 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc241) + %ileft_73 = tt.reshape %ileft_68 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_74 = tt.reshape %iright_72 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_75 = tt.reshape %new_idxs_60 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc244) + %left_idx_76 = arith.muli %y_idx_75, %ileft_64 : tensor<32x2x2xi32> loc(#loc246) + %left_idx_77 = "tt.reduce"(%left_idx_76) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc312) + %left_idx_78 = tt.expand_dims %left_idx_77 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc248) + %left_idx_79 = tt.broadcast %left_idx_78 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc249) + %right_idx_80 = arith.muli %y_idx_75, %flip_24 : tensor<32x2x2xi32> loc(#loc251) + %right_idx_81 = "tt.reduce"(%right_idx_80) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc315) + %right_idx_82 = tt.expand_dims %right_idx_81 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc253) + %right_idx_83 = tt.broadcast %right_idx_82 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc254) + %left_idx_84 = tt.reshape %left_idx_79 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_85 = tt.reshape %right_idx_83 : 
tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_86 = arith.cmpi slt, %ileft_73, %iright_74 : tensor<8x16xi32> loc(#loc257) + %eq_87 = arith.cmpi eq, %ileft_73, %iright_74 : tensor<8x16xi32> loc(#loc258) + %cond_88 = arith.cmpi sgt, %left_idx_84, %right_idx_85 : tensor<8x16xi32> loc(#loc259) + %cond_89 = arith.andi %eq_87, %cond_88 : tensor<8x16xi1> loc(#loc260) + %cond_90 = arith.ori %cond_86, %cond_89 : tensor<8x16xi1> loc(#loc261) + %cond_91 = arith.extui %cond_90 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) + %cond_92 = arith.xori %cond_91, %flip_62 : tensor<8x16xi32> loc(#loc262) + %cond_93 = arith.cmpi ne, %cond_92, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret_94 = arith.xori %ileft_73, %iright_74 : tensor<8x16xi32> loc(#loc264) + %ret_95 = arith.select %cond_93, %ret_94, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_96 = arith.xori %ret_56, %ret_95 : tensor<8x16xi32> loc(#loc266) + %new_idxs_97 = arith.xori %left_idx_84, %right_idx_85 : tensor<8x16xi32> loc(#loc267) + %new_idxs_98 = arith.select %cond_93, %new_idxs_97, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_99 = arith.xori %new_idxs_60, %new_idxs_98 : tensor<8x16xi32> loc(#loc269) + %y_100 = tt.reshape %ret_96 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc232) + %ileft_101 = arith.muli %y_100, %ileft : tensor<64x2x1xi32> loc(#loc234) + %ileft_102 = "tt.reduce"(%ileft_101) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc307) + %ileft_103 = tt.expand_dims %ileft_102 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc236) + %ileft_104 = tt.broadcast %ileft_103 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc237) + %iright_105 = arith.muli %y_100, %iright : tensor<64x2x1xi32> 
loc(#loc238) + %iright_106 = "tt.reduce"(%iright_105) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc309) + %iright_107 = tt.expand_dims %iright_106 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc240) + %iright_108 = tt.broadcast %iright_107 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc241) + %ileft_109 = tt.reshape %ileft_104 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_110 = tt.reshape %iright_108 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_111 = tt.reshape %new_idxs_99 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc244) + %left_idx_112 = arith.muli %y_idx_111, %ileft : tensor<64x2x1xi32> loc(#loc246) + %left_idx_113 = "tt.reduce"(%left_idx_112) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc312) + %left_idx_114 = tt.expand_dims %left_idx_113 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc248) + %left_idx_115 = tt.broadcast %left_idx_114 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc249) + %right_idx_116 = arith.muli %y_idx_111, %iright : tensor<64x2x1xi32> loc(#loc251) + %right_idx_117 = "tt.reduce"(%right_idx_116) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc315) + 
%right_idx_118 = tt.expand_dims %right_idx_117 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc253) + %right_idx_119 = tt.broadcast %right_idx_118 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc254) + %left_idx_120 = tt.reshape %left_idx_115 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_121 = tt.reshape %right_idx_119 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_122 = arith.cmpi slt, %ileft_109, %iright_110 : tensor<8x16xi32> loc(#loc257) + %eq_123 = arith.cmpi eq, %ileft_109, %iright_110 : tensor<8x16xi32> loc(#loc258) + %cond_124 = arith.cmpi sgt, %left_idx_120, %right_idx_121 : tensor<8x16xi32> loc(#loc259) + %cond_125 = arith.andi %eq_123, %cond_124 : tensor<8x16xi1> loc(#loc260) + %cond_126 = arith.ori %cond_122, %cond_125 : tensor<8x16xi1> loc(#loc261) + %cond_127 = arith.extui %cond_126 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) + %cond_128 = arith.xori %cond_127, %flip_62 : tensor<8x16xi32> loc(#loc262) + %cond_129 = arith.cmpi ne, %cond_128, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret_130 = arith.xori %ileft_109, %iright_110 : tensor<8x16xi32> loc(#loc264) + %ret_131 = arith.select %cond_129, %ret_130, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_132 = arith.xori %ret_96, %ret_131 : tensor<8x16xi32> loc(#loc266) + %new_idxs_133 = arith.xori %left_idx_120, %right_idx_121 : tensor<8x16xi32> loc(#loc267) + %new_idxs_134 = arith.select %cond_129, %new_idxs_133, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_135 = arith.xori %new_idxs_99, %new_idxs_134 : tensor<8x16xi32> loc(#loc269) + %flip_136 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc219) + %flip_137 = tt.reshape %flip_136 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc220) + %y_138 = tt.reshape %ret_132 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc232) + %ileft_139 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<16x2x4xi32> loc(#loc234) + %ileft_140 = 
arith.muli %y_138, %ileft_139 : tensor<16x2x4xi32> loc(#loc234) + %ileft_141 = "tt.reduce"(%ileft_140) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc307) + %ileft_142 = tt.expand_dims %ileft_141 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc236) + %ileft_143 = tt.broadcast %ileft_142 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc237) + %iright_144 = arith.muli %y_138, %flip_61 : tensor<16x2x4xi32> loc(#loc238) + %iright_145 = "tt.reduce"(%iright_144) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc309) + %iright_146 = tt.expand_dims %iright_145 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc240) + %iright_147 = tt.broadcast %iright_146 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc241) + %ileft_148 = tt.reshape %ileft_143 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_149 = tt.reshape %iright_147 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_150 = tt.reshape %new_idxs_135 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc244) + %left_idx_151 = arith.muli %y_idx_150, %ileft_139 : tensor<16x2x4xi32> loc(#loc246) + %left_idx_152 = "tt.reduce"(%left_idx_151) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc312) + 
%left_idx_153 = tt.expand_dims %left_idx_152 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc248) + %left_idx_154 = tt.broadcast %left_idx_153 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc249) + %right_idx_155 = arith.muli %y_idx_150, %flip_61 : tensor<16x2x4xi32> loc(#loc251) + %right_idx_156 = "tt.reduce"(%right_idx_155) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc315) + %right_idx_157 = tt.expand_dims %right_idx_156 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc253) + %right_idx_158 = tt.broadcast %right_idx_157 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc254) + %left_idx_159 = tt.reshape %left_idx_154 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_160 = tt.reshape %right_idx_158 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_161 = arith.cmpi slt, %ileft_148, %iright_149 : tensor<8x16xi32> loc(#loc257) + %eq_162 = arith.cmpi eq, %ileft_148, %iright_149 : tensor<8x16xi32> loc(#loc258) + %cond_163 = arith.cmpi sgt, %left_idx_159, %right_idx_160 : tensor<8x16xi32> loc(#loc259) + %cond_164 = arith.andi %eq_162, %cond_163 : tensor<8x16xi1> loc(#loc260) + %cond_165 = arith.ori %cond_161, %cond_164 : tensor<8x16xi1> loc(#loc261) + %cond_166 = arith.extui %cond_165 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) + %cond_167 = arith.xori %cond_166, %flip_137 : tensor<8x16xi32> loc(#loc262) + %cond_168 = arith.cmpi ne, %cond_167, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret_169 = arith.xori %ileft_148, %iright_149 : tensor<8x16xi32> loc(#loc264) + %ret_170 = arith.select %cond_168, %ret_169, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_171 = arith.xori %ret_132, %ret_170 : tensor<8x16xi32> 
loc(#loc266) + %new_idxs_172 = arith.xori %left_idx_159, %right_idx_160 : tensor<8x16xi32> loc(#loc267) + %new_idxs_173 = arith.select %cond_168, %new_idxs_172, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_174 = arith.xori %new_idxs_135, %new_idxs_173 : tensor<8x16xi32> loc(#loc269) + %y_175 = tt.reshape %ret_171 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc232) + %ileft_176 = arith.muli %y_175, %ileft_64 : tensor<32x2x2xi32> loc(#loc234) + %ileft_177 = "tt.reduce"(%ileft_176) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc307) + %ileft_178 = tt.expand_dims %ileft_177 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc236) + %ileft_179 = tt.broadcast %ileft_178 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc237) + %iright_180 = arith.muli %y_175, %flip_24 : tensor<32x2x2xi32> loc(#loc238) + %iright_181 = "tt.reduce"(%iright_180) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc309) + %iright_182 = tt.expand_dims %iright_181 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc240) + %iright_183 = tt.broadcast %iright_182 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc241) + %ileft_184 = tt.reshape %ileft_179 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_185 = tt.reshape %iright_183 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_186 = tt.reshape %new_idxs_174 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc244) + %left_idx_187 = arith.muli %y_idx_186, %ileft_64 : 
tensor<32x2x2xi32> loc(#loc246) + %left_idx_188 = "tt.reduce"(%left_idx_187) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc312) + %left_idx_189 = tt.expand_dims %left_idx_188 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc248) + %left_idx_190 = tt.broadcast %left_idx_189 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc249) + %right_idx_191 = arith.muli %y_idx_186, %flip_24 : tensor<32x2x2xi32> loc(#loc251) + %right_idx_192 = "tt.reduce"(%right_idx_191) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc315) + %right_idx_193 = tt.expand_dims %right_idx_192 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc253) + %right_idx_194 = tt.broadcast %right_idx_193 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc254) + %left_idx_195 = tt.reshape %left_idx_190 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_196 = tt.reshape %right_idx_194 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_197 = arith.cmpi slt, %ileft_184, %iright_185 : tensor<8x16xi32> loc(#loc257) + %eq_198 = arith.cmpi eq, %ileft_184, %iright_185 : tensor<8x16xi32> loc(#loc258) + %cond_199 = arith.cmpi sgt, %left_idx_195, %right_idx_196 : tensor<8x16xi32> loc(#loc259) + %cond_200 = arith.andi %eq_198, %cond_199 : tensor<8x16xi1> loc(#loc260) + %cond_201 = arith.ori %cond_197, %cond_200 : tensor<8x16xi1> loc(#loc261) + %cond_202 = arith.extui %cond_201 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) 
+ %cond_203 = arith.xori %cond_202, %flip_137 : tensor<8x16xi32> loc(#loc262) + %cond_204 = arith.cmpi ne, %cond_203, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret_205 = arith.xori %ileft_184, %iright_185 : tensor<8x16xi32> loc(#loc264) + %ret_206 = arith.select %cond_204, %ret_205, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_207 = arith.xori %ret_171, %ret_206 : tensor<8x16xi32> loc(#loc266) + %new_idxs_208 = arith.xori %left_idx_195, %right_idx_196 : tensor<8x16xi32> loc(#loc267) + %new_idxs_209 = arith.select %cond_204, %new_idxs_208, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_210 = arith.xori %new_idxs_174, %new_idxs_209 : tensor<8x16xi32> loc(#loc269) + %y_211 = tt.reshape %ret_207 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc232) + %ileft_212 = arith.muli %y_211, %ileft : tensor<64x2x1xi32> loc(#loc234) + %ileft_213 = "tt.reduce"(%ileft_212) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc307) + %ileft_214 = tt.expand_dims %ileft_213 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc236) + %ileft_215 = tt.broadcast %ileft_214 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc237) + %iright_216 = arith.muli %y_211, %iright : tensor<64x2x1xi32> loc(#loc238) + %iright_217 = "tt.reduce"(%iright_216) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc309) + %iright_218 = tt.expand_dims %iright_217 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc240) + %iright_219 = tt.broadcast 
%iright_218 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc241) + %ileft_220 = tt.reshape %ileft_215 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_221 = tt.reshape %iright_219 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_222 = tt.reshape %new_idxs_210 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc244) + %left_idx_223 = arith.muli %y_idx_222, %ileft : tensor<64x2x1xi32> loc(#loc246) + %left_idx_224 = "tt.reduce"(%left_idx_223) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc312) + %left_idx_225 = tt.expand_dims %left_idx_224 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc248) + %left_idx_226 = tt.broadcast %left_idx_225 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc249) + %right_idx_227 = arith.muli %y_idx_222, %iright : tensor<64x2x1xi32> loc(#loc251) + %right_idx_228 = "tt.reduce"(%right_idx_227) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc315) + %right_idx_229 = tt.expand_dims %right_idx_228 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc253) + %right_idx_230 = tt.broadcast %right_idx_229 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc254) + %left_idx_231 = tt.reshape %left_idx_226 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_232 = tt.reshape %right_idx_230 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_233 = arith.cmpi slt, %ileft_220, %iright_221 : tensor<8x16xi32> loc(#loc257) + %eq_234 = 
arith.cmpi eq, %ileft_220, %iright_221 : tensor<8x16xi32> loc(#loc258) + %cond_235 = arith.cmpi sgt, %left_idx_231, %right_idx_232 : tensor<8x16xi32> loc(#loc259) + %cond_236 = arith.andi %eq_234, %cond_235 : tensor<8x16xi1> loc(#loc260) + %cond_237 = arith.ori %cond_233, %cond_236 : tensor<8x16xi1> loc(#loc261) + %cond_238 = arith.extui %cond_237 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc262) + %cond_239 = arith.xori %cond_238, %flip_137 : tensor<8x16xi32> loc(#loc262) + %cond_240 = arith.cmpi ne, %cond_239, %cst_3 : tensor<8x16xi32> loc(#loc263) + %ret_241 = arith.xori %ileft_220, %iright_221 : tensor<8x16xi32> loc(#loc264) + %ret_242 = arith.select %cond_240, %ret_241, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_243 = arith.xori %ret_207, %ret_242 : tensor<8x16xi32> loc(#loc266) + %new_idxs_244 = arith.xori %left_idx_231, %right_idx_232 : tensor<8x16xi32> loc(#loc267) + %new_idxs_245 = arith.select %cond_240, %new_idxs_244, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_246 = arith.xori %new_idxs_210, %new_idxs_245 : tensor<8x16xi32> loc(#loc269) + %y_247 = tt.reshape %ret_243 : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc232) + %ileft_248 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<8x2x8xi32> loc(#loc234) + %ileft_249 = arith.muli %y_247, %ileft_248 : tensor<8x2x8xi32> loc(#loc234) + %ileft_250 = "tt.reduce"(%ileft_249) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc307) + %ileft_251 = tt.expand_dims %ileft_250 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc236) + %ileft_252 = tt.broadcast %ileft_251 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc237) + %iright_253 = arith.muli %y_247, %flip_136 : tensor<8x2x8xi32> loc(#loc238) + 
%iright_254 = "tt.reduce"(%iright_253) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc309) + %iright_255 = tt.expand_dims %iright_254 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc240) + %iright_256 = tt.broadcast %iright_255 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc241) + %ileft_257 = tt.reshape %ileft_252 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_258 = tt.reshape %iright_256 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_259 = tt.reshape %new_idxs_246 : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc244) + %left_idx_260 = arith.muli %y_idx_259, %ileft_248 : tensor<8x2x8xi32> loc(#loc246) + %left_idx_261 = "tt.reduce"(%left_idx_260) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc312) + %left_idx_262 = tt.expand_dims %left_idx_261 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc248) + %left_idx_263 = tt.broadcast %left_idx_262 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc249) + %right_idx_264 = arith.muli %y_idx_259, %flip_136 : tensor<8x2x8xi32> loc(#loc251) + %right_idx_265 = "tt.reduce"(%right_idx_264) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc315) + %right_idx_266 = tt.expand_dims 
%right_idx_265 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc253) + %right_idx_267 = tt.broadcast %right_idx_266 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc254) + %left_idx_268 = tt.reshape %left_idx_263 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_269 = tt.reshape %right_idx_267 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_270 = arith.cmpi slt, %ileft_257, %iright_258 : tensor<8x16xi32> loc(#loc257) + %eq_271 = arith.cmpi eq, %ileft_257, %iright_258 : tensor<8x16xi32> loc(#loc258) + %cond_272 = arith.cmpi sgt, %left_idx_268, %right_idx_269 : tensor<8x16xi32> loc(#loc259) + %cond_273 = arith.andi %eq_271, %cond_272 : tensor<8x16xi1> loc(#loc260) + %cond_274 = arith.ori %cond_270, %cond_273 : tensor<8x16xi1> loc(#loc261) + %ret_275 = arith.xori %ileft_257, %iright_258 : tensor<8x16xi32> loc(#loc264) + %ret_276 = arith.select %cond_274, %ret_275, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_277 = arith.xori %ret_243, %ret_276 : tensor<8x16xi32> loc(#loc266) + %new_idxs_278 = arith.xori %left_idx_268, %right_idx_269 : tensor<8x16xi32> loc(#loc267) + %new_idxs_279 = arith.select %cond_274, %new_idxs_278, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_280 = arith.xori %new_idxs_246, %new_idxs_279 : tensor<8x16xi32> loc(#loc269) + %y_281 = tt.reshape %ret_277 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc232) + %ileft_282 = arith.muli %y_281, %ileft_139 : tensor<16x2x4xi32> loc(#loc234) + %ileft_283 = "tt.reduce"(%ileft_282) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc307) + %ileft_284 = tt.expand_dims %ileft_283 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc236) + %ileft_285 = tt.broadcast 
%ileft_284 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc237) + %iright_286 = arith.muli %y_281, %flip_61 : tensor<16x2x4xi32> loc(#loc238) + %iright_287 = "tt.reduce"(%iright_286) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc309) + %iright_288 = tt.expand_dims %iright_287 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc240) + %iright_289 = tt.broadcast %iright_288 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc241) + %ileft_290 = tt.reshape %ileft_285 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_291 = tt.reshape %iright_289 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_292 = tt.reshape %new_idxs_280 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc244) + %left_idx_293 = arith.muli %y_idx_292, %ileft_139 : tensor<16x2x4xi32> loc(#loc246) + %left_idx_294 = "tt.reduce"(%left_idx_293) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc312) + %left_idx_295 = tt.expand_dims %left_idx_294 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc248) + %left_idx_296 = tt.broadcast %left_idx_295 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc249) + %right_idx_297 = arith.muli %y_idx_292, %flip_61 : tensor<16x2x4xi32> loc(#loc251) + %right_idx_298 = "tt.reduce"(%right_idx_297) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 
loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc315) + %right_idx_299 = tt.expand_dims %right_idx_298 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc253) + %right_idx_300 = tt.broadcast %right_idx_299 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc254) + %left_idx_301 = tt.reshape %left_idx_296 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_302 = tt.reshape %right_idx_300 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_303 = arith.cmpi slt, %ileft_290, %iright_291 : tensor<8x16xi32> loc(#loc257) + %eq_304 = arith.cmpi eq, %ileft_290, %iright_291 : tensor<8x16xi32> loc(#loc258) + %cond_305 = arith.cmpi sgt, %left_idx_301, %right_idx_302 : tensor<8x16xi32> loc(#loc259) + %cond_306 = arith.andi %eq_304, %cond_305 : tensor<8x16xi1> loc(#loc260) + %cond_307 = arith.ori %cond_303, %cond_306 : tensor<8x16xi1> loc(#loc261) + %ret_308 = arith.xori %ileft_290, %iright_291 : tensor<8x16xi32> loc(#loc264) + %ret_309 = arith.select %cond_307, %ret_308, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_310 = arith.xori %ret_277, %ret_309 : tensor<8x16xi32> loc(#loc266) + %new_idxs_311 = arith.xori %left_idx_301, %right_idx_302 : tensor<8x16xi32> loc(#loc267) + %new_idxs_312 = arith.select %cond_307, %new_idxs_311, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_313 = arith.xori %new_idxs_280, %new_idxs_312 : tensor<8x16xi32> loc(#loc269) + %y_314 = tt.reshape %ret_310 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc232) + %ileft_315 = arith.muli %y_314, %ileft_64 : tensor<32x2x2xi32> loc(#loc234) + %ileft_316 = "tt.reduce"(%ileft_315) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<32x2x2xi32>) -> 
tensor<32x2xi32> loc(#loc307) + %ileft_317 = tt.expand_dims %ileft_316 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc236) + %ileft_318 = tt.broadcast %ileft_317 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc237) + %iright_319 = arith.muli %y_314, %flip_24 : tensor<32x2x2xi32> loc(#loc238) + %iright_320 = "tt.reduce"(%iright_319) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc309) + %iright_321 = tt.expand_dims %iright_320 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc240) + %iright_322 = tt.broadcast %iright_321 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc241) + %ileft_323 = tt.reshape %ileft_318 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_324 = tt.reshape %iright_322 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_325 = tt.reshape %new_idxs_313 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc244) + %left_idx_326 = arith.muli %y_idx_325, %ileft_64 : tensor<32x2x2xi32> loc(#loc246) + %left_idx_327 = "tt.reduce"(%left_idx_326) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc312) + %left_idx_328 = tt.expand_dims %left_idx_327 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc248) + %left_idx_329 = tt.broadcast %left_idx_328 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc249) + %right_idx_330 = arith.muli %y_idx_325, %flip_24 : tensor<32x2x2xi32> loc(#loc251) + %right_idx_331 = "tt.reduce"(%right_idx_330) <{axis = 1 : i32}> ({ + 
^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc315) + %right_idx_332 = tt.expand_dims %right_idx_331 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc253) + %right_idx_333 = tt.broadcast %right_idx_332 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc254) + %left_idx_334 = tt.reshape %left_idx_329 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_335 = tt.reshape %right_idx_333 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_336 = arith.cmpi slt, %ileft_323, %iright_324 : tensor<8x16xi32> loc(#loc257) + %eq_337 = arith.cmpi eq, %ileft_323, %iright_324 : tensor<8x16xi32> loc(#loc258) + %cond_338 = arith.cmpi sgt, %left_idx_334, %right_idx_335 : tensor<8x16xi32> loc(#loc259) + %cond_339 = arith.andi %eq_337, %cond_338 : tensor<8x16xi1> loc(#loc260) + %cond_340 = arith.ori %cond_336, %cond_339 : tensor<8x16xi1> loc(#loc261) + %ret_341 = arith.xori %ileft_323, %iright_324 : tensor<8x16xi32> loc(#loc264) + %ret_342 = arith.select %cond_340, %ret_341, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc265) + %ret_343 = arith.xori %ret_310, %ret_342 : tensor<8x16xi32> loc(#loc266) + %new_idxs_344 = arith.xori %left_idx_334, %right_idx_335 : tensor<8x16xi32> loc(#loc267) + %new_idxs_345 = arith.select %cond_340, %new_idxs_344, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_346 = arith.xori %new_idxs_313, %new_idxs_345 : tensor<8x16xi32> loc(#loc269) + %y_347 = tt.reshape %ret_343 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc232) + %ileft_348 = arith.muli %y_347, %ileft : tensor<64x2x1xi32> loc(#loc234) + %ileft_349 = "tt.reduce"(%ileft_348) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 
loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc307) + %ileft_350 = tt.expand_dims %ileft_349 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc236) + %ileft_351 = tt.broadcast %ileft_350 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc237) + %iright_352 = arith.muli %y_347, %iright : tensor<64x2x1xi32> loc(#loc238) + %iright_353 = "tt.reduce"(%iright_352) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc309) + %iright_354 = tt.expand_dims %iright_353 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc240) + %iright_355 = tt.broadcast %iright_354 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc241) + %ileft_356 = tt.reshape %ileft_351 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc242) + %iright_357 = tt.reshape %iright_355 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc243) + %y_idx_358 = tt.reshape %new_idxs_346 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc244) + %left_idx_359 = arith.muli %y_idx_358, %ileft : tensor<64x2x1xi32> loc(#loc246) + %left_idx_360 = "tt.reduce"(%left_idx_359) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc312) + %left_idx_361 = tt.expand_dims %left_idx_360 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc248) + %left_idx_362 = tt.broadcast %left_idx_361 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> 
loc(#loc249) + %right_idx_363 = arith.muli %y_idx_358, %iright : tensor<64x2x1xi32> loc(#loc251) + %right_idx_364 = "tt.reduce"(%right_idx_363) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc315) + %right_idx_365 = tt.expand_dims %right_idx_364 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc253) + %right_idx_366 = tt.broadcast %right_idx_365 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc254) + %left_idx_367 = tt.reshape %left_idx_362 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc255) + %right_idx_368 = tt.reshape %right_idx_366 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc256) + %cond_369 = arith.cmpi slt, %ileft_356, %iright_357 : tensor<8x16xi32> loc(#loc257) + %eq_370 = arith.cmpi eq, %ileft_356, %iright_357 : tensor<8x16xi32> loc(#loc258) + %cond_371 = arith.cmpi sgt, %left_idx_367, %right_idx_368 : tensor<8x16xi32> loc(#loc259) + %cond_372 = arith.andi %eq_370, %cond_371 : tensor<8x16xi1> loc(#loc260) + %cond_373 = arith.ori %cond_369, %cond_372 : tensor<8x16xi1> loc(#loc261) + %new_idxs_374 = arith.xori %left_idx_367, %right_idx_368 : tensor<8x16xi32> loc(#loc267) + %new_idxs_375 = arith.select %cond_373, %new_idxs_374, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc268) + %new_idxs_376 = arith.xori %new_idxs_346, %new_idxs_375 : tensor<8x16xi32> loc(#loc269) + %tmp14 = arith.cmpi eq, %tmp0_21, %cst_6 : tensor<8x16xi64> loc(#loc192) + %tmp16 = arith.extui %tmp14 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc224) + %y_377 = tt.reshape %tmp16 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc270) + %ileft_378 = arith.muli %y_377, %ileft : tensor<64x2x1xi32> loc(#loc271) + %ileft_379 = "tt.reduce"(%ileft_378) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 
loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc317) + %ileft_380 = tt.expand_dims %ileft_379 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc273) + %ileft_381 = tt.broadcast %ileft_380 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc274) + %iright_382 = arith.muli %y_377, %iright : tensor<64x2x1xi32> loc(#loc275) + %iright_383 = "tt.reduce"(%iright_382) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc319) + %iright_384 = tt.expand_dims %iright_383 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc277) + %iright_385 = tt.broadcast %iright_384 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc278) + %ileft_386 = tt.reshape %ileft_381 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_387 = tt.reshape %iright_385 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %cond_388 = arith.cmpi slt, %ileft_386, %iright_387 : tensor<8x16xi32> loc(#loc281) + %eq_389 = arith.cmpi eq, %ileft_386, %iright_387 : tensor<8x16xi32> loc(#loc282) + %cond_390 = arith.andi %eq_389, %cond_49 : tensor<8x16xi1> loc(#loc283) + %cond_391 = arith.ori %cond_388, %cond_390 : tensor<8x16xi1> loc(#loc284) + %cond_392 = arith.extui %cond_391 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc285) + %cond_393 = arith.xori %cond_392, %flip_25 : tensor<8x16xi32> loc(#loc285) + %cond_394 = arith.cmpi ne, %cond_393, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_395 = arith.xori %ileft_386, %iright_387 : tensor<8x16xi32> loc(#loc287) + %ret_396 = arith.select %cond_394, %ret_395, %cst_3 : 
tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_397 = arith.xori %tmp16, %ret_396 : tensor<8x16xi32> loc(#loc289) + %new_idxs_398 = arith.select %cond_394, %new_idxs, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_399 = arith.xori %new_idxs_59, %new_idxs_398 : tensor<8x16xi32> loc(#loc291) + %y_400 = tt.reshape %ret_397 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc270) + %ileft_401 = arith.muli %y_400, %ileft_64 : tensor<32x2x2xi32> loc(#loc271) + %ileft_402 = "tt.reduce"(%ileft_401) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc317) + %ileft_403 = tt.expand_dims %ileft_402 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc273) + %ileft_404 = tt.broadcast %ileft_403 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc274) + %iright_405 = arith.muli %y_400, %flip_24 : tensor<32x2x2xi32> loc(#loc275) + %iright_406 = "tt.reduce"(%iright_405) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc319) + %iright_407 = tt.expand_dims %iright_406 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc277) + %iright_408 = tt.broadcast %iright_407 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc278) + %ileft_409 = tt.reshape %ileft_404 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_410 = tt.reshape %iright_408 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_411 = tt.reshape %new_idxs_399 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc292) + %left_idx_412 = arith.muli %y_idx_411, 
%ileft_64 : tensor<32x2x2xi32> loc(#loc293) + %left_idx_413 = "tt.reduce"(%left_idx_412) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc321) + %left_idx_414 = tt.expand_dims %left_idx_413 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc295) + %left_idx_415 = tt.broadcast %left_idx_414 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc296) + %right_idx_416 = arith.muli %y_idx_411, %flip_24 : tensor<32x2x2xi32> loc(#loc297) + %right_idx_417 = "tt.reduce"(%right_idx_416) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc323) + %right_idx_418 = tt.expand_dims %right_idx_417 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc299) + %right_idx_419 = tt.broadcast %right_idx_418 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc300) + %left_idx_420 = tt.reshape %left_idx_415 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_421 = tt.reshape %right_idx_419 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_422 = arith.cmpi slt, %ileft_409, %iright_410 : tensor<8x16xi32> loc(#loc281) + %eq_423 = arith.cmpi eq, %ileft_409, %iright_410 : tensor<8x16xi32> loc(#loc282) + %cond_424 = arith.cmpi sgt, %left_idx_420, %right_idx_421 : tensor<8x16xi32> loc(#loc303) + %cond_425 = arith.andi %eq_423, %cond_424 : tensor<8x16xi1> loc(#loc283) + %cond_426 = arith.ori %cond_422, %cond_425 : tensor<8x16xi1> loc(#loc284) + %cond_427 = arith.extui %cond_426 : tensor<8x16xi1> to tensor<8x16xi32> 
loc(#loc285) + %cond_428 = arith.xori %cond_427, %flip_62 : tensor<8x16xi32> loc(#loc285) + %cond_429 = arith.cmpi ne, %cond_428, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_430 = arith.xori %ileft_409, %iright_410 : tensor<8x16xi32> loc(#loc287) + %ret_431 = arith.select %cond_429, %ret_430, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_432 = arith.xori %ret_397, %ret_431 : tensor<8x16xi32> loc(#loc289) + %new_idxs_433 = arith.xori %left_idx_420, %right_idx_421 : tensor<8x16xi32> loc(#loc304) + %new_idxs_434 = arith.select %cond_429, %new_idxs_433, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_435 = arith.xori %new_idxs_399, %new_idxs_434 : tensor<8x16xi32> loc(#loc291) + %y_436 = tt.reshape %ret_432 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc270) + %ileft_437 = arith.muli %y_436, %ileft : tensor<64x2x1xi32> loc(#loc271) + %ileft_438 = "tt.reduce"(%ileft_437) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc317) + %ileft_439 = tt.expand_dims %ileft_438 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc273) + %ileft_440 = tt.broadcast %ileft_439 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc274) + %iright_441 = arith.muli %y_436, %iright : tensor<64x2x1xi32> loc(#loc275) + %iright_442 = "tt.reduce"(%iright_441) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc319) + %iright_443 = tt.expand_dims %iright_442 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc277) + %iright_444 = 
tt.broadcast %iright_443 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc278) + %ileft_445 = tt.reshape %ileft_440 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_446 = tt.reshape %iright_444 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_447 = tt.reshape %new_idxs_435 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_448 = arith.muli %y_idx_447, %ileft : tensor<64x2x1xi32> loc(#loc293) + %left_idx_449 = "tt.reduce"(%left_idx_448) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc321) + %left_idx_450 = tt.expand_dims %left_idx_449 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc295) + %left_idx_451 = tt.broadcast %left_idx_450 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc296) + %right_idx_452 = arith.muli %y_idx_447, %iright : tensor<64x2x1xi32> loc(#loc297) + %right_idx_453 = "tt.reduce"(%right_idx_452) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc323) + %right_idx_454 = tt.expand_dims %right_idx_453 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc299) + %right_idx_455 = tt.broadcast %right_idx_454 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc300) + %left_idx_456 = tt.reshape %left_idx_451 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_457 = tt.reshape %right_idx_455 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_458 = arith.cmpi slt, %ileft_445, %iright_446 : tensor<8x16xi32> loc(#loc281) + 
%eq_459 = arith.cmpi eq, %ileft_445, %iright_446 : tensor<8x16xi32> loc(#loc282) + %cond_460 = arith.cmpi sgt, %left_idx_456, %right_idx_457 : tensor<8x16xi32> loc(#loc303) + %cond_461 = arith.andi %eq_459, %cond_460 : tensor<8x16xi1> loc(#loc283) + %cond_462 = arith.ori %cond_458, %cond_461 : tensor<8x16xi1> loc(#loc284) + %cond_463 = arith.extui %cond_462 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc285) + %cond_464 = arith.xori %cond_463, %flip_62 : tensor<8x16xi32> loc(#loc285) + %cond_465 = arith.cmpi ne, %cond_464, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_466 = arith.xori %ileft_445, %iright_446 : tensor<8x16xi32> loc(#loc287) + %ret_467 = arith.select %cond_465, %ret_466, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_468 = arith.xori %ret_432, %ret_467 : tensor<8x16xi32> loc(#loc289) + %new_idxs_469 = arith.xori %left_idx_456, %right_idx_457 : tensor<8x16xi32> loc(#loc304) + %new_idxs_470 = arith.select %cond_465, %new_idxs_469, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_471 = arith.xori %new_idxs_435, %new_idxs_470 : tensor<8x16xi32> loc(#loc291) + %y_472 = tt.reshape %ret_468 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc270) + %ileft_473 = arith.muli %y_472, %ileft_139 : tensor<16x2x4xi32> loc(#loc271) + %ileft_474 = "tt.reduce"(%ileft_473) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc317) + %ileft_475 = tt.expand_dims %ileft_474 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc273) + %ileft_476 = tt.broadcast %ileft_475 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc274) + %iright_477 = arith.muli %y_472, %flip_61 : tensor<16x2x4xi32> loc(#loc275) + %iright_478 = "tt.reduce"(%iright_477) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 
loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc319) + %iright_479 = tt.expand_dims %iright_478 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc277) + %iright_480 = tt.broadcast %iright_479 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc278) + %ileft_481 = tt.reshape %ileft_476 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_482 = tt.reshape %iright_480 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_483 = tt.reshape %new_idxs_471 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc292) + %left_idx_484 = arith.muli %y_idx_483, %ileft_139 : tensor<16x2x4xi32> loc(#loc293) + %left_idx_485 = "tt.reduce"(%left_idx_484) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc321) + %left_idx_486 = tt.expand_dims %left_idx_485 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc295) + %left_idx_487 = tt.broadcast %left_idx_486 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc296) + %right_idx_488 = arith.muli %y_idx_483, %flip_61 : tensor<16x2x4xi32> loc(#loc297) + %right_idx_489 = "tt.reduce"(%right_idx_488) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc323) + %right_idx_490 = tt.expand_dims %right_idx_489 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> 
loc(#loc299) + %right_idx_491 = tt.broadcast %right_idx_490 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc300) + %left_idx_492 = tt.reshape %left_idx_487 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_493 = tt.reshape %right_idx_491 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_494 = arith.cmpi slt, %ileft_481, %iright_482 : tensor<8x16xi32> loc(#loc281) + %eq_495 = arith.cmpi eq, %ileft_481, %iright_482 : tensor<8x16xi32> loc(#loc282) + %cond_496 = arith.cmpi sgt, %left_idx_492, %right_idx_493 : tensor<8x16xi32> loc(#loc303) + %cond_497 = arith.andi %eq_495, %cond_496 : tensor<8x16xi1> loc(#loc283) + %cond_498 = arith.ori %cond_494, %cond_497 : tensor<8x16xi1> loc(#loc284) + %cond_499 = arith.extui %cond_498 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc285) + %cond_500 = arith.xori %cond_499, %flip_137 : tensor<8x16xi32> loc(#loc285) + %cond_501 = arith.cmpi ne, %cond_500, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_502 = arith.xori %ileft_481, %iright_482 : tensor<8x16xi32> loc(#loc287) + %ret_503 = arith.select %cond_501, %ret_502, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_504 = arith.xori %ret_468, %ret_503 : tensor<8x16xi32> loc(#loc289) + %new_idxs_505 = arith.xori %left_idx_492, %right_idx_493 : tensor<8x16xi32> loc(#loc304) + %new_idxs_506 = arith.select %cond_501, %new_idxs_505, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_507 = arith.xori %new_idxs_471, %new_idxs_506 : tensor<8x16xi32> loc(#loc291) + %y_508 = tt.reshape %ret_504 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc270) + %ileft_509 = arith.muli %y_508, %ileft_64 : tensor<32x2x2xi32> loc(#loc271) + %ileft_510 = "tt.reduce"(%ileft_509) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<32x2x2xi32>) -> 
tensor<32x2xi32> loc(#loc317) + %ileft_511 = tt.expand_dims %ileft_510 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc273) + %ileft_512 = tt.broadcast %ileft_511 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc274) + %iright_513 = arith.muli %y_508, %flip_24 : tensor<32x2x2xi32> loc(#loc275) + %iright_514 = "tt.reduce"(%iright_513) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc319) + %iright_515 = tt.expand_dims %iright_514 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc277) + %iright_516 = tt.broadcast %iright_515 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc278) + %ileft_517 = tt.reshape %ileft_512 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_518 = tt.reshape %iright_516 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_519 = tt.reshape %new_idxs_507 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc292) + %left_idx_520 = arith.muli %y_idx_519, %ileft_64 : tensor<32x2x2xi32> loc(#loc293) + %left_idx_521 = "tt.reduce"(%left_idx_520) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc321) + %left_idx_522 = tt.expand_dims %left_idx_521 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc295) + %left_idx_523 = tt.broadcast %left_idx_522 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc296) + %right_idx_524 = arith.muli %y_idx_519, %flip_24 : tensor<32x2x2xi32> loc(#loc297) + %right_idx_525 = "tt.reduce"(%right_idx_524) <{axis = 1 : i32}> ({ + 
^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc323) + %right_idx_526 = tt.expand_dims %right_idx_525 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc299) + %right_idx_527 = tt.broadcast %right_idx_526 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc300) + %left_idx_528 = tt.reshape %left_idx_523 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_529 = tt.reshape %right_idx_527 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_530 = arith.cmpi slt, %ileft_517, %iright_518 : tensor<8x16xi32> loc(#loc281) + %eq_531 = arith.cmpi eq, %ileft_517, %iright_518 : tensor<8x16xi32> loc(#loc282) + %cond_532 = arith.cmpi sgt, %left_idx_528, %right_idx_529 : tensor<8x16xi32> loc(#loc303) + %cond_533 = arith.andi %eq_531, %cond_532 : tensor<8x16xi1> loc(#loc283) + %cond_534 = arith.ori %cond_530, %cond_533 : tensor<8x16xi1> loc(#loc284) + %cond_535 = arith.extui %cond_534 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc285) + %cond_536 = arith.xori %cond_535, %flip_137 : tensor<8x16xi32> loc(#loc285) + %cond_537 = arith.cmpi ne, %cond_536, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_538 = arith.xori %ileft_517, %iright_518 : tensor<8x16xi32> loc(#loc287) + %ret_539 = arith.select %cond_537, %ret_538, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_540 = arith.xori %ret_504, %ret_539 : tensor<8x16xi32> loc(#loc289) + %new_idxs_541 = arith.xori %left_idx_528, %right_idx_529 : tensor<8x16xi32> loc(#loc304) + %new_idxs_542 = arith.select %cond_537, %new_idxs_541, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_543 = arith.xori %new_idxs_507, %new_idxs_542 : tensor<8x16xi32> loc(#loc291) + %y_544 = tt.reshape %ret_540 : tensor<8x16xi32> -> 
tensor<64x2x1xi32> loc(#loc270) + %ileft_545 = arith.muli %y_544, %ileft : tensor<64x2x1xi32> loc(#loc271) + %ileft_546 = "tt.reduce"(%ileft_545) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc317) + %ileft_547 = tt.expand_dims %ileft_546 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc273) + %ileft_548 = tt.broadcast %ileft_547 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc274) + %iright_549 = arith.muli %y_544, %iright : tensor<64x2x1xi32> loc(#loc275) + %iright_550 = "tt.reduce"(%iright_549) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc319) + %iright_551 = tt.expand_dims %iright_550 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc277) + %iright_552 = tt.broadcast %iright_551 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc278) + %ileft_553 = tt.reshape %ileft_548 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_554 = tt.reshape %iright_552 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_555 = tt.reshape %new_idxs_543 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_556 = arith.muli %y_idx_555, %ileft : tensor<64x2x1xi32> loc(#loc293) + %left_idx_557 = "tt.reduce"(%left_idx_556) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<64x2x1xi32>) -> 
tensor<64x1xi32> loc(#loc321) + %left_idx_558 = tt.expand_dims %left_idx_557 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc295) + %left_idx_559 = tt.broadcast %left_idx_558 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc296) + %right_idx_560 = arith.muli %y_idx_555, %iright : tensor<64x2x1xi32> loc(#loc297) + %right_idx_561 = "tt.reduce"(%right_idx_560) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc323) + %right_idx_562 = tt.expand_dims %right_idx_561 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc299) + %right_idx_563 = tt.broadcast %right_idx_562 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc300) + %left_idx_564 = tt.reshape %left_idx_559 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_565 = tt.reshape %right_idx_563 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_566 = arith.cmpi slt, %ileft_553, %iright_554 : tensor<8x16xi32> loc(#loc281) + %eq_567 = arith.cmpi eq, %ileft_553, %iright_554 : tensor<8x16xi32> loc(#loc282) + %cond_568 = arith.cmpi sgt, %left_idx_564, %right_idx_565 : tensor<8x16xi32> loc(#loc303) + %cond_569 = arith.andi %eq_567, %cond_568 : tensor<8x16xi1> loc(#loc283) + %cond_570 = arith.ori %cond_566, %cond_569 : tensor<8x16xi1> loc(#loc284) + %cond_571 = arith.extui %cond_570 : tensor<8x16xi1> to tensor<8x16xi32> loc(#loc285) + %cond_572 = arith.xori %cond_571, %flip_137 : tensor<8x16xi32> loc(#loc285) + %cond_573 = arith.cmpi ne, %cond_572, %cst_3 : tensor<8x16xi32> loc(#loc286) + %ret_574 = arith.xori %ileft_553, %iright_554 : tensor<8x16xi32> loc(#loc287) + %ret_575 = arith.select %cond_573, %ret_574, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_576 = arith.xori 
%ret_540, %ret_575 : tensor<8x16xi32> loc(#loc289) + %new_idxs_577 = arith.xori %left_idx_564, %right_idx_565 : tensor<8x16xi32> loc(#loc304) + %new_idxs_578 = arith.select %cond_573, %new_idxs_577, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_579 = arith.xori %new_idxs_543, %new_idxs_578 : tensor<8x16xi32> loc(#loc291) + %y_580 = tt.reshape %ret_576 : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc270) + %ileft_581 = arith.muli %y_580, %ileft_248 : tensor<8x2x8xi32> loc(#loc271) + %ileft_582 = "tt.reduce"(%ileft_581) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc317) + %ileft_583 = tt.expand_dims %ileft_582 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc273) + %ileft_584 = tt.broadcast %ileft_583 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc274) + %iright_585 = arith.muli %y_580, %flip_136 : tensor<8x2x8xi32> loc(#loc275) + %iright_586 = "tt.reduce"(%iright_585) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc319) + %iright_587 = tt.expand_dims %iright_586 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc277) + %iright_588 = tt.broadcast %iright_587 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc278) + %ileft_589 = tt.reshape %ileft_584 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_590 = tt.reshape %iright_588 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_591 = tt.reshape %new_idxs_579 : tensor<8x16xi32> -> tensor<8x2x8xi32> loc(#loc292) + %left_idx_592 = arith.muli 
%y_idx_591, %ileft_248 : tensor<8x2x8xi32> loc(#loc293) + %left_idx_593 = "tt.reduce"(%left_idx_592) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc321) + %left_idx_594 = tt.expand_dims %left_idx_593 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc295) + %left_idx_595 = tt.broadcast %left_idx_594 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc296) + %right_idx_596 = arith.muli %y_idx_591, %flip_136 : tensor<8x2x8xi32> loc(#loc297) + %right_idx_597 = "tt.reduce"(%right_idx_596) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<8x2x8xi32>) -> tensor<8x8xi32> loc(#loc323) + %right_idx_598 = tt.expand_dims %right_idx_597 {axis = 1 : i32} : tensor<8x8xi32> -> tensor<8x1x8xi32> loc(#loc299) + %right_idx_599 = tt.broadcast %right_idx_598 : tensor<8x1x8xi32> -> tensor<8x2x8xi32> loc(#loc300) + %left_idx_600 = tt.reshape %left_idx_595 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_601 = tt.reshape %right_idx_599 : tensor<8x2x8xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_602 = arith.cmpi slt, %ileft_589, %iright_590 : tensor<8x16xi32> loc(#loc281) + %eq_603 = arith.cmpi eq, %ileft_589, %iright_590 : tensor<8x16xi32> loc(#loc282) + %cond_604 = arith.cmpi sgt, %left_idx_600, %right_idx_601 : tensor<8x16xi32> loc(#loc303) + %cond_605 = arith.andi %eq_603, %cond_604 : tensor<8x16xi1> loc(#loc283) + %cond_606 = arith.ori %cond_602, %cond_605 : tensor<8x16xi1> loc(#loc284) + %ret_607 = arith.xori %ileft_589, %iright_590 : tensor<8x16xi32> 
loc(#loc287) + %ret_608 = arith.select %cond_606, %ret_607, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_609 = arith.xori %ret_576, %ret_608 : tensor<8x16xi32> loc(#loc289) + %new_idxs_610 = arith.xori %left_idx_600, %right_idx_601 : tensor<8x16xi32> loc(#loc304) + %new_idxs_611 = arith.select %cond_606, %new_idxs_610, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_612 = arith.xori %new_idxs_579, %new_idxs_611 : tensor<8x16xi32> loc(#loc291) + %y_613 = tt.reshape %ret_609 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc270) + %ileft_614 = arith.muli %y_613, %ileft_139 : tensor<16x2x4xi32> loc(#loc271) + %ileft_615 = "tt.reduce"(%ileft_614) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc317) + %ileft_616 = tt.expand_dims %ileft_615 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc273) + %ileft_617 = tt.broadcast %ileft_616 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc274) + %iright_618 = arith.muli %y_613, %flip_61 : tensor<16x2x4xi32> loc(#loc275) + %iright_619 = "tt.reduce"(%iright_618) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc319) + %iright_620 = tt.expand_dims %iright_619 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc277) + %iright_621 = tt.broadcast %iright_620 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc278) + %ileft_622 = tt.reshape %ileft_617 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_623 = tt.reshape %iright_621 : tensor<16x2x4xi32> -> 
tensor<8x16xi32> loc(#loc280) + %y_idx_624 = tt.reshape %new_idxs_612 : tensor<8x16xi32> -> tensor<16x2x4xi32> loc(#loc292) + %left_idx_625 = arith.muli %y_idx_624, %ileft_139 : tensor<16x2x4xi32> loc(#loc293) + %left_idx_626 = "tt.reduce"(%left_idx_625) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc321) + %left_idx_627 = tt.expand_dims %left_idx_626 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc295) + %left_idx_628 = tt.broadcast %left_idx_627 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc296) + %right_idx_629 = arith.muli %y_idx_624, %flip_61 : tensor<16x2x4xi32> loc(#loc297) + %right_idx_630 = "tt.reduce"(%right_idx_629) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<16x2x4xi32>) -> tensor<16x4xi32> loc(#loc323) + %right_idx_631 = tt.expand_dims %right_idx_630 {axis = 1 : i32} : tensor<16x4xi32> -> tensor<16x1x4xi32> loc(#loc299) + %right_idx_632 = tt.broadcast %right_idx_631 : tensor<16x1x4xi32> -> tensor<16x2x4xi32> loc(#loc300) + %left_idx_633 = tt.reshape %left_idx_628 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_634 = tt.reshape %right_idx_632 : tensor<16x2x4xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_635 = arith.cmpi slt, %ileft_622, %iright_623 : tensor<8x16xi32> loc(#loc281) + %eq_636 = arith.cmpi eq, %ileft_622, %iright_623 : tensor<8x16xi32> loc(#loc282) + %cond_637 = arith.cmpi sgt, %left_idx_633, %right_idx_634 : tensor<8x16xi32> loc(#loc303) + %cond_638 = arith.andi %eq_636, %cond_637 : 
tensor<8x16xi1> loc(#loc283) + %cond_639 = arith.ori %cond_635, %cond_638 : tensor<8x16xi1> loc(#loc284) + %ret_640 = arith.xori %ileft_622, %iright_623 : tensor<8x16xi32> loc(#loc287) + %ret_641 = arith.select %cond_639, %ret_640, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_642 = arith.xori %ret_609, %ret_641 : tensor<8x16xi32> loc(#loc289) + %new_idxs_643 = arith.xori %left_idx_633, %right_idx_634 : tensor<8x16xi32> loc(#loc304) + %new_idxs_644 = arith.select %cond_639, %new_idxs_643, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_645 = arith.xori %new_idxs_612, %new_idxs_644 : tensor<8x16xi32> loc(#loc291) + %y_646 = tt.reshape %ret_642 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc270) + %ileft_647 = arith.muli %y_646, %ileft_64 : tensor<32x2x2xi32> loc(#loc271) + %ileft_648 = "tt.reduce"(%ileft_647) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc317) + %ileft_649 = tt.expand_dims %ileft_648 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc273) + %ileft_650 = tt.broadcast %ileft_649 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc274) + %iright_651 = arith.muli %y_646, %flip_24 : tensor<32x2x2xi32> loc(#loc275) + %iright_652 = "tt.reduce"(%iright_651) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc319) + %iright_653 = tt.expand_dims %iright_652 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc277) + %iright_654 = tt.broadcast %iright_653 : tensor<32x1x2xi32> -> 
tensor<32x2x2xi32> loc(#loc278) + %ileft_655 = tt.reshape %ileft_650 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_656 = tt.reshape %iright_654 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_657 = tt.reshape %new_idxs_645 : tensor<8x16xi32> -> tensor<32x2x2xi32> loc(#loc292) + %left_idx_658 = arith.muli %y_idx_657, %ileft_64 : tensor<32x2x2xi32> loc(#loc293) + %left_idx_659 = "tt.reduce"(%left_idx_658) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc321) + %left_idx_660 = tt.expand_dims %left_idx_659 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc295) + %left_idx_661 = tt.broadcast %left_idx_660 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc296) + %right_idx_662 = arith.muli %y_idx_657, %flip_24 : tensor<32x2x2xi32> loc(#loc297) + %right_idx_663 = "tt.reduce"(%right_idx_662) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<32x2x2xi32>) -> tensor<32x2xi32> loc(#loc323) + %right_idx_664 = tt.expand_dims %right_idx_663 {axis = 1 : i32} : tensor<32x2xi32> -> tensor<32x1x2xi32> loc(#loc299) + %right_idx_665 = tt.broadcast %right_idx_664 : tensor<32x1x2xi32> -> tensor<32x2x2xi32> loc(#loc300) + %left_idx_666 = tt.reshape %left_idx_661 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_667 = tt.reshape %right_idx_665 : tensor<32x2x2xi32> -> tensor<8x16xi32> loc(#loc302) + %cond_668 = arith.cmpi slt, %ileft_655, %iright_656 : tensor<8x16xi32> loc(#loc281) + %eq_669 = arith.cmpi eq, %ileft_655, 
%iright_656 : tensor<8x16xi32> loc(#loc282) + %cond_670 = arith.cmpi sgt, %left_idx_666, %right_idx_667 : tensor<8x16xi32> loc(#loc303) + %cond_671 = arith.andi %eq_669, %cond_670 : tensor<8x16xi1> loc(#loc283) + %cond_672 = arith.ori %cond_668, %cond_671 : tensor<8x16xi1> loc(#loc284) + %ret_673 = arith.xori %ileft_655, %iright_656 : tensor<8x16xi32> loc(#loc287) + %ret_674 = arith.select %cond_672, %ret_673, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc288) + %ret_675 = arith.xori %ret_642, %ret_674 : tensor<8x16xi32> loc(#loc289) + %new_idxs_676 = arith.xori %left_idx_666, %right_idx_667 : tensor<8x16xi32> loc(#loc304) + %new_idxs_677 = arith.select %cond_672, %new_idxs_676, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_678 = arith.xori %new_idxs_645, %new_idxs_677 : tensor<8x16xi32> loc(#loc291) + %y_679 = tt.reshape %ret_675 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc270) + %ileft_680 = arith.muli %y_679, %ileft : tensor<64x2x1xi32> loc(#loc271) + %ileft_681 = "tt.reduce"(%ileft_680) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc317) + %ileft_682 = tt.expand_dims %ileft_681 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc273) + %ileft_683 = tt.broadcast %ileft_682 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc274) + %iright_684 = arith.muli %y_679, %iright : tensor<64x2x1xi32> loc(#loc275) + %iright_685 = "tt.reduce"(%iright_684) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc319) + %iright_686 = 
tt.expand_dims %iright_685 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc277) + %iright_687 = tt.broadcast %iright_686 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc278) + %ileft_688 = tt.reshape %ileft_683 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc279) + %iright_689 = tt.reshape %iright_687 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc280) + %y_idx_690 = tt.reshape %new_idxs_678 : tensor<8x16xi32> -> tensor<64x2x1xi32> loc(#loc292) + %left_idx_691 = arith.muli %y_idx_690, %ileft : tensor<64x2x1xi32> loc(#loc293) + %left_idx_692 = "tt.reduce"(%left_idx_691) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc321) + %left_idx_693 = tt.expand_dims %left_idx_692 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc295) + %left_idx_694 = tt.broadcast %left_idx_693 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc296) + %right_idx_695 = arith.muli %y_idx_690, %iright : tensor<64x2x1xi32> loc(#loc297) + %right_idx_696 = "tt.reduce"(%right_idx_695) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<64x2x1xi32>) -> tensor<64x1xi32> loc(#loc323) + %right_idx_697 = tt.expand_dims %right_idx_696 {axis = 1 : i32} : tensor<64x1xi32> -> tensor<64x1x1xi32> loc(#loc299) + %right_idx_698 = tt.broadcast %right_idx_697 : tensor<64x1x1xi32> -> tensor<64x2x1xi32> loc(#loc300) + %left_idx_699 = tt.reshape %left_idx_694 : tensor<64x2x1xi32> -> tensor<8x16xi32> loc(#loc301) + %right_idx_700 = tt.reshape %right_idx_698 : tensor<64x2x1xi32> -> 
tensor<8x16xi32> loc(#loc302) + %cond_701 = arith.cmpi slt, %ileft_688, %iright_689 : tensor<8x16xi32> loc(#loc281) + %eq_702 = arith.cmpi eq, %ileft_688, %iright_689 : tensor<8x16xi32> loc(#loc282) + %cond_703 = arith.cmpi sgt, %left_idx_699, %right_idx_700 : tensor<8x16xi32> loc(#loc303) + %cond_704 = arith.andi %eq_702, %cond_703 : tensor<8x16xi1> loc(#loc283) + %cond_705 = arith.ori %cond_701, %cond_704 : tensor<8x16xi1> loc(#loc284) + %new_idxs_706 = arith.xori %left_idx_699, %right_idx_700 : tensor<8x16xi32> loc(#loc304) + %new_idxs_707 = arith.select %cond_705, %new_idxs_706, %cst_3 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc290) + %new_idxs_708 = arith.xori %new_idxs_678, %new_idxs_707 : tensor<8x16xi32> loc(#loc291) + %tmp20 = arith.extui %tmp5 : tensor<8x16xi1> to tensor<8x16xi64> loc(#loc226) + %tmp23 = arith.select %tmp0_20, %tmp20, %cst_7 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc197) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_713: i64 loc(callsite(#loc1 at #loc198)), %tmp24_714: i64 loc(callsite(#loc1 at #loc198))): + %tmp24_715 = arith.addi %tmp24_713, %tmp24_714 : i64 loc(#loc305) + tt.reduce.return %tmp24_715 : i64 loc(#loc227) + }) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc227) + %tmp24_709 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc199) + %tmp25 = arith.extui %tmp14 : tensor<8x16xi1> to tensor<8x16xi64> loc(#loc229) + %tmp28 = arith.select %tmp0_20, %tmp25, %cst_7 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc201) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_713: i64 loc(callsite(#loc1 at #loc202)), %tmp29_714: i64 loc(callsite(#loc1 at #loc202))): + %tmp29_715 = arith.addi %tmp29_713, %tmp29_714 : i64 loc(#loc306) + tt.reduce.return %tmp29_715 : i64 loc(#loc230) + }) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc230) + %tmp29_710 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc203) + %tmp30 = arith.trunci %tmp24_709 : 
tensor<8x1xi64> to tensor<8x1xi32> loc(#loc204) + %tmp31 = arith.trunci %tmp29_710 : tensor<8x1xi64> to tensor<8x1xi32> loc(#loc205) + %tmp34 = tt.broadcast %tmp30 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc206) + %tmp34_711 = arith.cmpi slt, %tmp0_15, %tmp34 : tensor<8x16xi32> loc(#loc206) + %tmp36 = arith.select %tmp34_711, %new_idxs_376, %cst_5 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc207) + %tmp38 = arith.addi %tmp36, %cst_4 : tensor<8x16xi32> loc(#loc208) + %tmp39 = arith.cmpi slt, %tmp36, %cst_3 : tensor<8x16xi32> loc(#loc209) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc210) + %0 = arith.cmpi sge, %tmp40, %cst_3 : tensor<8x16xi32> loc(#loc88) + %1 = arith.cmpi slt, %tmp40, %cst_4 : tensor<8x16xi32> loc(#loc89) + %2 = arith.andi %0, %1 : tensor<8x16xi1> loc(#loc90) + %3 = arith.xori %xmask_13, %cst_2 : tensor<8x1xi1> loc(#loc91) + %4 = tt.broadcast %3 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc92) + %5 = arith.ori %2, %4 : tensor<8x16xi1> loc(#loc92) + tt.assert %5, "index out of bounds: 0 <= tmp40 < 17" : tensor<8x16xi1> loc(#loc93) + %tmp45 = tt.broadcast %tmp31 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc211) + %tmp45_712 = arith.cmpi slt, %tmp0_15, %tmp45 : tensor<8x16xi32> loc(#loc211) + %tmp46 = arith.select %tmp45_712, %new_idxs_708, %cst_5 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc212) + %tmp47 = arith.addi %tmp46, %cst_4 : tensor<8x16xi32> loc(#loc213) + %tmp48 = arith.cmpi slt, %tmp46, %cst_3 : tensor<8x16xi32> loc(#loc214) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<8x16xi1>, tensor<8x16xi32> loc(#loc215) + %6 = arith.cmpi sge, %tmp49, %cst_3 : tensor<8x16xi32> loc(#loc99) + %7 = arith.cmpi slt, %tmp49, %cst_4 : tensor<8x16xi32> loc(#loc100) + %8 = arith.andi %6, %7 : tensor<8x16xi1> loc(#loc101) + %9 = arith.ori %8, %4 : tensor<8x16xi1> loc(#loc102) + tt.assert %9, "index out of bounds: 0 <= tmp49 < 17" : tensor<8x16xi1> loc(#loc103) + %10 = tt.splat %out_ptr4 : !tt.ptr -> 
tensor<8x1x!tt.ptr> loc(#loc104) + %11 = tt.addptr %10, %xindex_12 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc104) + tt.store %11, %tmp30, %xmask_13 : tensor<8x1x!tt.ptr> loc(#loc105) + %12 = tt.splat %out_ptr5 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc106) + %13 = tt.addptr %12, %xindex_12 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc106) + tt.store %13, %tmp31, %xmask_13 : tensor<8x1x!tt.ptr> loc(#loc107) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc108) + %15 = tt.addptr %14, %tmp0_17 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc108) + tt.store %15, %new_idxs_376, %tmp0_20 : tensor<8x16x!tt.ptr> loc(#loc109) + %16 = arith.muli %xindex_12, %cst_1 : tensor<8x1xi32> loc(#loc110) + %17 = tt.broadcast %16 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc111) + %18 = arith.addi %tmp40, %17 : tensor<8x16xi32> loc(#loc111) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc112) + %20 = tt.addptr %19, %18 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc112) + tt.store %20, %cst_0, %tmp0_20 : tensor<8x16x!tt.ptr> loc(#loc113) + %21 = tt.splat %out_ptr8 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc114) + %22 = tt.addptr %21, %tmp0_17 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc114) + tt.store %22, %new_idxs_708, %tmp0_20 : tensor<8x16x!tt.ptr> loc(#loc115) + %23 = arith.addi %tmp49, %17 : tensor<8x16xi32> loc(#loc116) + %24 = tt.splat %out_ptr9 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc117) + %25 = tt.addptr %24, %23 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc117) + tt.store %25, %cst_0, %tmp0_20 : tensor<8x16x!tt.ptr> loc(#loc118) + tt.return loc(#loc119) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc4 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:28) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc18 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":45:34) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) 
+#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc59 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc65 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc67 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc68 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc77 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc92 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc106 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc117 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc119 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc129 = loc("xmask"(#loc2)) +#loc130 = loc("xoffset"(#loc3)) +#loc131 = 
loc("xoffset"(#loc4)) +#loc132 = loc("xindex"(#loc5)) +#loc133 = loc("xindex"(#loc6)) +#loc134 = loc("xindex"(#loc7)) +#loc135 = loc("r0_index"(#loc8)) +#loc136 = loc("r0_index"(#loc9)) +#loc137 = loc("tmp0"(#loc10)) +#loc138 = loc("tmp0"(#loc11)) +#loc139 = loc("tmp0"(#loc12)) +#loc140 = loc("tmp0"(#loc13)) +#loc141 = loc("tmp2"(#loc14)) +#loc142 = loc("tmp4"(#loc15)) +#loc143 = loc("tmp5"(#loc16)) +#loc144 = loc("tmp7"(#loc17)) +#loc145 = loc("tmp6"(#loc18)) +#loc146 = loc("tmp9"(#loc19)) +#loc147 = loc("tmp11"(#loc20)) +#loc148 = loc("flip"(#loc21)) +#loc150 = loc("flip"(#loc24)) +#loc151 = loc("flip"(#loc25)) +#loc152 = loc("flip"(#loc26)) +#loc153 = loc("y"(#loc27)) +#loc154 = loc("left_mask"(#loc29)) +#loc155 = loc("ileft"(#loc30)) +#loc157 = loc("ileft"(#loc34)) +#loc158 = loc("ileft"(#loc35)) +#loc159 = loc("iright"(#loc36)) +#loc161 = loc("iright"(#loc38)) +#loc162 = loc("iright"(#loc39)) +#loc163 = loc("ileft"(#loc40)) +#loc164 = loc("iright"(#loc41)) +#loc165 = loc("y_idx"(#loc42)) +#loc166 = loc("left_idx"(#loc43)) +#loc167 = loc("left_idx"(#loc44)) +#loc168 = loc("input"(#loc45)) +#loc170 = loc("left_idx"(#loc47)) +#loc171 = loc("left_idx"(#loc48)) +#loc172 = loc("right_idx"(#loc49)) +#loc173 = loc("right_idx"(#loc50)) +#loc175 = loc("right_idx"(#loc52)) +#loc176 = loc("right_idx"(#loc53)) +#loc177 = loc("left_idx"(#loc54)) +#loc178 = loc("right_idx"(#loc55)) +#loc179 = loc("cond"(#loc56)) +#loc180 = loc("eq"(#loc57)) +#loc181 = loc("cond"(#loc58)) +#loc182 = loc("cond"(#loc59)) +#loc183 = loc("cond"(#loc60)) +#loc184 = loc("cond"(#loc61)) +#loc185 = loc("cond"(#loc62)) +#loc186 = loc("ret"(#loc63)) +#loc187 = loc("ret"(#loc64)) +#loc188 = loc("ret"(#loc65)) +#loc189 = loc("new_idxs"(#loc66)) +#loc190 = loc("new_idxs"(#loc67)) +#loc191 = loc("new_idxs"(#loc68)) +#loc192 = loc("tmp14"(#loc69)) +#loc193 = loc("tmp16"(#loc70)) +#loc194 = loc("tmp15"(#loc71)) +#loc196 = loc("tmp20"(#loc73)) +#loc197 = loc("tmp23"(#loc74)) +#loc199 = loc("tmp24"(#loc76)) 
+#loc200 = loc("tmp25"(#loc77)) +#loc201 = loc("tmp28"(#loc78)) +#loc203 = loc("tmp29"(#loc80)) +#loc204 = loc("tmp30"(#loc81)) +#loc205 = loc("tmp31"(#loc82)) +#loc206 = loc("tmp34"(#loc83)) +#loc207 = loc("tmp36"(#loc84)) +#loc208 = loc("tmp38"(#loc85)) +#loc209 = loc("tmp39"(#loc86)) +#loc210 = loc("tmp40"(#loc87)) +#loc211 = loc("tmp45"(#loc94)) +#loc212 = loc("tmp46"(#loc95)) +#loc213 = loc("tmp47"(#loc96)) +#loc214 = loc("tmp48"(#loc97)) +#loc215 = loc("tmp49"(#loc98)) +#loc216 = loc(fused[#loc144, #loc145]) +#loc217 = loc(callsite(#loc148 at #loc149)) +#loc218 = loc(callsite(#loc150 at #loc149)) +#loc219 = loc(callsite(#loc151 at #loc149)) +#loc220 = loc(callsite(#loc152 at #loc149)) +#loc222 = loc("cond"(#loc179)) +#loc223 = loc("eq"(#loc180)) +#loc224 = loc(fused[#loc193, #loc194]) +#loc226 = loc(fused[#loc196, #loc144, #loc145]) +#loc227 = loc(callsite(#loc31 at #loc198)) +#loc229 = loc(fused[#loc200, #loc193, #loc194]) +#loc230 = loc(callsite(#loc31 at #loc202)) +#loc232 = loc(callsite(#loc153 at #loc221)) +#loc233 = loc(callsite(#loc154 at #loc221)) +#loc234 = loc(callsite(#loc155 at #loc221)) +#loc236 = loc(callsite(#loc157 at #loc221)) +#loc237 = loc(callsite(#loc158 at #loc221)) +#loc238 = loc(callsite(#loc159 at #loc221)) +#loc240 = loc(callsite(#loc161 at #loc221)) +#loc241 = loc(callsite(#loc162 at #loc221)) +#loc242 = loc(callsite(#loc163 at #loc221)) +#loc243 = loc(callsite(#loc164 at #loc221)) +#loc244 = loc(callsite(#loc165 at #loc221)) +#loc245 = loc(callsite(#loc166 at #loc221)) +#loc246 = loc(callsite(#loc167 at #loc221)) +#loc248 = loc(callsite(#loc170 at #loc221)) +#loc249 = loc(callsite(#loc171 at #loc221)) +#loc250 = loc(callsite(#loc172 at #loc221)) +#loc251 = loc(callsite(#loc173 at #loc221)) +#loc253 = loc(callsite(#loc175 at #loc221)) +#loc254 = loc(callsite(#loc176 at #loc221)) +#loc255 = loc(callsite(#loc177 at #loc221)) +#loc256 = loc(callsite(#loc178 at #loc221)) +#loc257 = loc(callsite(#loc222 at #loc221)) +#loc258 = 
loc(callsite(#loc223 at #loc221)) +#loc259 = loc(callsite(#loc181 at #loc221)) +#loc260 = loc(callsite(#loc182 at #loc221)) +#loc261 = loc(callsite(#loc183 at #loc221)) +#loc262 = loc(callsite(#loc184 at #loc221)) +#loc263 = loc(callsite(#loc185 at #loc221)) +#loc264 = loc(callsite(#loc186 at #loc221)) +#loc265 = loc(callsite(#loc187 at #loc221)) +#loc266 = loc(callsite(#loc188 at #loc221)) +#loc267 = loc(callsite(#loc189 at #loc221)) +#loc268 = loc(callsite(#loc190 at #loc221)) +#loc269 = loc(callsite(#loc191 at #loc221)) +#loc270 = loc(callsite(#loc153 at #loc225)) +#loc271 = loc(callsite(#loc155 at #loc225)) +#loc273 = loc(callsite(#loc157 at #loc225)) +#loc274 = loc(callsite(#loc158 at #loc225)) +#loc275 = loc(callsite(#loc159 at #loc225)) +#loc277 = loc(callsite(#loc161 at #loc225)) +#loc278 = loc(callsite(#loc162 at #loc225)) +#loc279 = loc(callsite(#loc163 at #loc225)) +#loc280 = loc(callsite(#loc164 at #loc225)) +#loc281 = loc(callsite(#loc222 at #loc225)) +#loc282 = loc(callsite(#loc223 at #loc225)) +#loc283 = loc(callsite(#loc182 at #loc225)) +#loc284 = loc(callsite(#loc183 at #loc225)) +#loc285 = loc(callsite(#loc184 at #loc225)) +#loc286 = loc(callsite(#loc185 at #loc225)) +#loc287 = loc(callsite(#loc186 at #loc225)) +#loc288 = loc(callsite(#loc187 at #loc225)) +#loc289 = loc(callsite(#loc188 at #loc225)) +#loc290 = loc(callsite(#loc190 at #loc225)) +#loc291 = loc(callsite(#loc191 at #loc225)) +#loc292 = loc(callsite(#loc165 at #loc225)) +#loc293 = loc(callsite(#loc167 at #loc225)) +#loc295 = loc(callsite(#loc170 at #loc225)) +#loc296 = loc(callsite(#loc171 at #loc225)) +#loc297 = loc(callsite(#loc173 at #loc225)) +#loc299 = loc(callsite(#loc175 at #loc225)) +#loc300 = loc(callsite(#loc176 at #loc225)) +#loc301 = loc(callsite(#loc177 at #loc225)) +#loc302 = loc(callsite(#loc178 at #loc225)) +#loc303 = loc(callsite(#loc181 at #loc225)) +#loc304 = loc(callsite(#loc189 at #loc225)) +#loc305 = loc(callsite(#loc33 at #loc227)) +#loc306 = loc(callsite(#loc33 
at #loc230)) +#loc307 = loc(callsite(#loc31 at #loc235)) +#loc309 = loc(callsite(#loc31 at #loc239)) +#loc311 = loc(callsite(#loc168 at #loc247)) +#loc312 = loc(callsite(#loc31 at #loc247)) +#loc314 = loc(callsite(#loc168 at #loc252)) +#loc315 = loc(callsite(#loc31 at #loc252)) +#loc317 = loc(callsite(#loc31 at #loc272)) +#loc319 = loc(callsite(#loc31 at #loc276)) +#loc321 = loc(callsite(#loc31 at #loc294)) +#loc323 = loc(callsite(#loc31 at #loc298)) +#loc325 = loc(callsite(#loc33 at #loc307)) +#loc326 = loc(callsite(#loc33 at #loc309)) +#loc327 = loc(callsite(#loc33 at #loc312)) +#loc328 = loc(callsite(#loc33 at #loc315)) +#loc329 = loc(callsite(#loc33 at #loc317)) +#loc330 = loc(callsite(#loc33 at #loc319)) +#loc331 = loc(callsite(#loc33 at #loc321)) +#loc332 = loc(callsite(#loc33 at #loc323)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/__grp__triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/__grp__triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..5d550d3f4dd67738d931c53f3a1086280144f091 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/__grp__triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_zeros_0.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.source", "triton_red_fused_zeros_0.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttir", "triton_red_fused_zeros_0.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttgir", "triton_red_fused_zeros_0.llir": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.llir", "triton_red_fused_zeros_0.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ptx", "triton_red_fused_zeros_0.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.cubin", "triton_red_fused_zeros_0.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..aaf53998c285ad7a4d217741a8006c91c4d6e884 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..2e24a8d2fb96551654c519b583f69bd398f40b0a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"hash": "00c59fc984d266ad088fac53f4f2eb19a06a315dab54d30fa39d98e13a72e275", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 4, "num_ctas": 1, 
"num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 256, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_zeros_0"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.llir new file mode 100644 index 0000000000000000000000000000000000000000..45e1981adce1a24c116db1cb9d3959f0d4662e58 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.llir @@ -0,0 +1,206 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_zeros_0(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i64 %3, i64 %4, i32 %5, i32 %6, ptr addrspace(1) readnone 
captures(none) %7, ptr addrspace(1) readnone captures(none) %8) local_unnamed_addr #0 !dbg !4 { + %10 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %11 = shl i32 %10, 6, !dbg !8 + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %13 = and i32 %12, 126, !dbg !9 + %14 = lshr exact i32 %13, 1, !dbg !9 + %15 = or disjoint i32 %14, %11, !dbg !10 + %16 = icmp slt i32 %15, %5, !dbg !11 + %17 = shl nuw nsw i32 %12, 2, !dbg !12 + %18 = and i32 %17, 4, !dbg !12 + %19 = sext i32 %15 to i64, !dbg !13 + %.frozen = freeze i64 %3, !dbg !14 + %20 = sdiv i64 %19, %.frozen, !dbg !14 + %21 = mul i64 %20, %.frozen, !dbg !13 + %.decomposed = sub i64 %19, %21, !dbg !13 + %22 = srem i64 %20, 32, !dbg !15 + %23 = sdiv i64 %19, %4, !dbg !16 + %.not = icmp ne i64 %.decomposed, 0, !dbg !17 + %24 = icmp slt i32 %11, 0, !dbg !21 + %25 = icmp slt i64 %3, 0, !dbg !22 + %26 = xor i1 %24, %25, !dbg !23 + %narrow = select i1 %26, i1 %.not, i1 false, !dbg !24 + %27 = sext i1 %narrow to i64, !dbg !24 + %28 = add nsw i64 %20, %27, !dbg !24 + %29 = shl i64 %3, 12, !dbg !25 + %30 = mul i64 %29, %23, !dbg !26 + %31 = icmp slt i64 %3, 2, !dbg !27 + %32 = icmp sgt i64 %3, 1, !dbg !28 + %33 = select i1 %32, i64 %3, i64 0, !dbg !29 + %34 = zext i1 %31 to i64, !dbg !30 + %35 = add i64 %33, %34, !dbg !31 + %36 = shl i64 %35, 7, !dbg !32 + %37 = mul i64 %36, %28, !dbg !33 + %.idx = shl nsw i64 %22, 8 + %38 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx + %.idx1 = shl nsw i64 %.decomposed, 13 + %invariant.gep = getelementptr i8, ptr addrspace(1) %38, i64 %.idx1, !dbg !34 + %invariant.gep3 = getelementptr bfloat, ptr addrspace(1) %invariant.gep, i64 %30, !dbg !34 + %.idx2 = shl nsw i64 %.decomposed, 8 + %39 = getelementptr i8, ptr addrspace(1) %1, i64 %.idx2 + %invariant.gep5 = getelementptr bfloat, ptr addrspace(1) %39, i64 %37, !dbg !34 + %40 = zext nneg i32 %18 to i64, !dbg !34 + br label %41, !dbg !34 + +41: ; preds = %9, %41 + %indvars.iv = phi i64 [ 0, %9 ], [ 
%indvars.iv.next, %41 ] + %42 = phi float [ 0.000000e+00, %9 ], [ %83, %41 ] + %43 = phi float [ 0.000000e+00, %9 ], [ %84, %41 ] + %44 = phi float [ 0.000000e+00, %9 ], [ %85, %41 ] + %45 = phi float [ 0.000000e+00, %9 ], [ %86, %41 ] + %46 = or disjoint i64 %indvars.iv, %40, !dbg !35 + %gep4 = getelementptr bfloat, ptr addrspace(1) %invariant.gep3, i64 %46, !dbg !36 + %47 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !37 + %48 = tail call { i32, i32 } asm sideeffect "mov.u32 $0, $2;\0A\09mov.u32 $1, $3;\0A\09@$6 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { $0, $1 }, [ $4 + 0 ], $5;", "=r,=r,r,r,l,l,b"(i32 0, i32 0, ptr addrspace(1) %gep4, i64 %47, i1 %16) #4, !dbg !37 + %49 = extractvalue { i32, i32 } %48, 0, !dbg !37 + %50 = bitcast i32 %49 to <2 x bfloat>, !dbg !37 + %51 = extractvalue { i32, i32 } %48, 1, !dbg !37 + %52 = bitcast i32 %51 to <2 x bfloat>, !dbg !37 + %53 = extractelement <2 x bfloat> %50, i64 0, !dbg !37 + %54 = extractelement <2 x bfloat> %50, i64 1, !dbg !37 + %55 = extractelement <2 x bfloat> %52, i64 0, !dbg !37 + %56 = extractelement <2 x bfloat> %52, i64 1, !dbg !37 + %57 = fpext bfloat %53 to float, !dbg !38 + %58 = fpext bfloat %54 to float, !dbg !38 + %59 = fpext bfloat %55 to float, !dbg !38 + %60 = fpext bfloat %56 to float, !dbg !38 + %gep = getelementptr bfloat, ptr addrspace(1) %invariant.gep5, i64 %46, !dbg !39 + %61 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !40 + %62 = tail call { i32, i32 } asm sideeffect "mov.u32 $0, $2;\0A\09mov.u32 $1, $3;\0A\09@$6 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { $0, $1 }, [ $4 + 0 ], $5;", "=r,=r,r,r,l,l,b"(i32 0, i32 0, ptr addrspace(1) %gep, i64 %61, i1 %16) #4, !dbg !40 + %63 = extractvalue { i32, i32 } %62, 0, !dbg !40 + %64 = bitcast i32 %63 to <2 x bfloat>, !dbg !40 + %65 = extractvalue { i32, i32 } %62, 1, !dbg !40 + 
%66 = bitcast i32 %65 to <2 x bfloat>, !dbg !40 + %67 = extractelement <2 x bfloat> %64, i64 0, !dbg !40 + %68 = extractelement <2 x bfloat> %64, i64 1, !dbg !40 + %69 = extractelement <2 x bfloat> %66, i64 0, !dbg !40 + %70 = extractelement <2 x bfloat> %66, i64 1, !dbg !40 + %71 = fpext bfloat %67 to float, !dbg !41 + %72 = fpext bfloat %68 to float, !dbg !41 + %73 = fpext bfloat %69 to float, !dbg !41 + %74 = fpext bfloat %70 to float, !dbg !41 + %75 = fmul float %57, %71, !dbg !42 + %76 = fmul float %58, %72, !dbg !42 + %77 = fmul float %59, %73, !dbg !42 + %78 = fmul float %60, %74, !dbg !42 + %79 = fadd float %42, %75, !dbg !43 + %80 = fadd float %43, %76, !dbg !43 + %81 = fadd float %44, %77, !dbg !43 + %82 = fadd float %45, %78, !dbg !43 + %83 = select i1 %16, float %79, float %42, !dbg !44 + %84 = select i1 %16, float %80, float %43, !dbg !44 + %85 = select i1 %16, float %81, float %44, !dbg !44 + %86 = select i1 %16, float %82, float %45, !dbg !44 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 8, !dbg !34 + %87 = icmp samesign ult i64 %indvars.iv, 120, !dbg !34 + br i1 %87, label %41, label %88, !dbg !34 + +88: ; preds = %41 + %89 = and i32 %12, 63, !dbg !9 + %90 = or disjoint i32 %11, %89, !dbg !10 + %91 = icmp slt i32 %90, %5, !dbg !11 + %92 = fadd float %83, %84, !dbg !45 + %93 = fadd float %85, %92, !dbg !45 + %94 = fadd float %86, %93, !dbg !45 + %95 = bitcast float %94 to i32, !dbg !49 + %96 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %95, i32 1, i32 31), !dbg !49 + %97 = bitcast i32 %96 to float, !dbg !49 + %98 = fadd float %94, %97, !dbg !45 + %99 = shl nuw nsw i32 %13, 1, !dbg !50 + %100 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %99, !dbg !50 + store float %98, ptr addrspace(3) %100, align 4, !dbg !50 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !50 + %101 = shl nuw nsw i32 %89, 2, !dbg !50 + %102 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %101, !dbg !50 
+ %103 = load i32, ptr addrspace(3) %102, align 4, !dbg !50 + %104 = sext i32 %90 to i64, !dbg !51 + %105 = getelementptr float, ptr addrspace(1) %2, i64 %104, !dbg !51 + %106 = and i32 %12, 64, !dbg !52 + %107 = icmp eq i32 %106, 0, !dbg !52 + %108 = and i1 %107, %91, !dbg !52 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %103, ptr addrspace(1) %105, i1 %108) #4, !dbg !52 + ret void, !dbg !53 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="128" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_zeros_0", linkageName: "triton_red_fused_zeros_0", scope: !1, file: !1, line: 18, type: !5, 
scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 22, column: 28, scope: !4) +!8 = !DILocation(line: 22, column: 33, scope: !4) +!9 = !DILocation(line: 23, column: 44, scope: !4) +!10 = !DILocation(line: 23, column: 23, scope: !4) +!11 = !DILocation(line: 24, column: 21, scope: !4) +!12 = !DILocation(line: 25, column: 37, scope: !4) +!13 = !DILocation(line: 27, column: 19, scope: !4) +!14 = !DILocation(line: 28, column: 21, scope: !4) +!15 = !DILocation(line: 28, column: 28, scope: !4) +!16 = !DILocation(line: 29, column: 19, scope: !4) +!17 = !DILocation(line: 74, column: 34, scope: !18, inlinedAt: !20) +!18 = distinct !DILexicalBlockFile(scope: !4, file: !19, discriminator: 0) +!19 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!20 = !DILocation(line: 30, column: 51, scope: !4) +!21 = !DILocation(line: 75, column: 25, scope: !18, inlinedAt: !20) +!22 = !DILocation(line: 75, column: 36, scope: !18, inlinedAt: !20) +!23 = !DILocation(line: 75, column: 32, scope: !18, inlinedAt: !20) +!24 = !DILocation(line: 75, column: 47, scope: !18, inlinedAt: !20) +!25 = !DILocation(line: 39, column: 65, scope: !4) +!26 = !DILocation(line: 39, column: 69, scope: !4) +!27 = !DILocation(line: 40, column: 73, scope: !4) +!28 = !DILocation(line: 40, column: 99, scope: !4) +!29 = !DILocation(line: 40, column: 90, scope: !4) +!30 = !DILocation(line: 40, scope: !4) +!31 = !DILocation(line: 40, column: 81, scope: !4) +!32 = !DILocation(line: 40, column: 54, scope: !4) +!33 = !DILocation(line: 40, column: 58, scope: !4) +!34 = !DILocation(line: 33, column: 40, scope: !4) +!35 = !DILocation(line: 34, column: 31, scope: !4) +!36 = !DILocation(line: 39, column: 34, scope: !4) +!37 = !DILocation(line: 39, column: 74, scope: !4) +!38 = !DILocation(line: 39, column: 136, scope: !4) +!39 = 
!DILocation(line: 40, column: 34, scope: !4) +!40 = !DILocation(line: 40, column: 106, scope: !4) +!41 = !DILocation(line: 40, column: 168, scope: !4) +!42 = !DILocation(line: 41, column: 22, scope: !4) +!43 = !DILocation(line: 43, column: 23, scope: !4) +!44 = !DILocation(line: 44, column: 48, scope: !4) +!45 = !DILocation(line: 261, column: 15, scope: !46, inlinedAt: !48) +!46 = distinct !DILexicalBlockFile(scope: !4, file: !47, discriminator: 0) +!47 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!48 = !DILocation(line: 45, column: 25, scope: !4) +!49 = !DILocation(line: 291, column: 36, scope: !46, inlinedAt: !48) +!50 = !DILocation(line: 45, column: 28, scope: !4) +!51 = !DILocation(line: 49, column: 25, scope: !4) +!52 = !DILocation(line: 49, column: 36, scope: !4) +!53 = !DILocation(line: 49, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..af27b5fc261ec2ab229d7149babb28bf01408d81 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ptx @@ -0,0 +1,474 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_zeros_0 // -- Begin function triton_red_fused_zeros_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_zeros_0 +.visible .entry triton_red_fused_zeros_0( + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_2, + .param .u64 
triton_red_fused_zeros_0_param_3, + .param .u64 triton_red_fused_zeros_0_param_4, + .param .u32 triton_red_fused_zeros_0_param_5, + .param .u32 triton_red_fused_zeros_0_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_7, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_8 +) +.reqntid 128 +{ + .reg .pred %p<16>; + .reg .b16 %rs<9>; + .reg .b32 %r<66>; + .reg .b64 %rd<67>; + .loc 1 18 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:18:0 + +// %bb.0: + ld.param.b64 %rd24, [triton_red_fused_zeros_0_param_4]; + ld.param.b64 %rd23, [triton_red_fused_zeros_0_param_3]; +$L__tmp0: + .loc 1 22 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:22:28 + mov.u32 %r14, %ctaid.x; + .loc 1 22 33 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:22:33 + shl.b32 %r1, %r14, 6; + .loc 1 23 44 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:44 + mov.u32 %r2, %tid.x; + bfe.u32 %r4, %r2, 1, 6; + .loc 1 23 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:23 + or.b32 %r15, %r4, %r1; + .loc 1 27 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:27:19 + cvt.s64.s32 %rd1, %r15; + .loc 1 28 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:28:21 + or.b64 %rd26, %rd1, %rd23; + and.b64 %rd27, %rd26, -4294967296; + setp.ne.b64 %p1, %rd27, 0; + cvt.u32.u64 %r61, %rd1; + @%p1 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + div.s64 %rd62, %rd1, %rd23; + bra.uni $L__BB0_3; +$L__BB0_1: + cvt.u32.u64 %r16, %rd23; + div.u32 %r18, %r61, %r16; + cvt.u64.u32 %rd62, %r18; +$L__BB0_3: + .loc 1 0 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:0:21 + ld.param.b32 %r13, [triton_red_fused_zeros_0_param_5]; + ld.param.b64 %rd21, [triton_red_fused_zeros_0_param_1]; + ld.param.b64 %rd20, [triton_red_fused_zeros_0_param_0]; + .loc 1 27 19 // 
cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:27:19 + mul.lo.s64 %rd6, %rd62, %rd23; + sub.s64 %rd7, %rd1, %rd6; + .loc 1 28 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:28:28 + shr.s64 %rd28, %rd62, 63; + shr.u64 %rd29, %rd28, 59; + add.s64 %rd30, %rd62, %rd29; + and.b64 %rd31, %rd30, -32; + sub.s64 %rd8, %rd62, %rd31; + .loc 1 29 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:29:19 + or.b64 %rd32, %rd1, %rd24; + and.b64 %rd33, %rd32, -4294967296; + setp.ne.b64 %p2, %rd33, 0; + @%p2 bra $L__BB0_5; + bra.uni $L__BB0_4; +$L__BB0_5: + div.s64 %rd63, %rd1, %rd24; + bra.uni $L__BB0_6; +$L__BB0_4: + cvt.u32.u64 %r19, %rd24; + div.u32 %r21, %r61, %r19; + cvt.u64.u32 %rd63, %r21; +$L__BB0_6: + .loc 1 0 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:0:19 + ld.param.b64 %rd22, [triton_red_fused_zeros_0_param_2]; + and.b32 %r3, %r2, 126; +$L__tmp1: + .loc 2 74 34 // triton_helpers.py:74:34 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.ne.b64 %p3, %rd7, 0; + .loc 2 75 25 // triton_helpers.py:75:25 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.lt.s32 %p4, %r1, 0; + .loc 2 75 36 // triton_helpers.py:75:36 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.lt.s64 %p5, %rd23, 0; + .loc 2 75 32 // triton_helpers.py:75:32 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + xor.pred %p6, %p4, %p5; +$L__tmp2: + .loc 1 40 73 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:73 + setp.lt.s64 %p7, %rd23, 2; + .loc 1 40 99 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:99 + setp.gt.s64 %p8, %rd23, 1; + .loc 1 40 90 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:90 + selp.b64 %rd35, %rd23, 0, %p8; + .loc 1 40 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40 + selp.b64 %rd36, 1, 0, %p7; + .loc 1 40 81 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:81 + add.s64 %rd37, 
%rd35, %rd36; + shl.b64 %rd38, %rd8, 8; + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + and.pred %p9, %p3, %p6; + selp.b64 %rd39, -1, 0, %p9; + add.s64 %rd40, %rd62, %rd39; + mul.lo.s64 %rd41, %rd40, %rd37; + shl.b64 %rd42, %rd41, 8; + add.s32 %r23, %r1, %r4; + mad.wide.s32 %rd43, %r23, 256, %rd42; + and.b32 %r24, %r2, 1; + mul.wide.u32 %rd44, %r24, 8; + or.b64 %rd45, %rd43, %rd44; + shl.b64 %rd46, %rd6, 8; + sub.s64 %rd47, %rd45, %rd46; + add.s64 %rd65, %rd21, %rd47; + mul.lo.s64 %rd48, %rd63, %rd23; + shl.b64 %rd49, %rd48, 13; + mad.wide.s32 %rd50, %r23, 8192, %rd49; + add.s64 %rd51, %rd50, %rd38; + or.b64 %rd52, %rd51, %rd44; + shl.b64 %rd53, %rd6, 13; + sub.s64 %rd54, %rd52, %rd53; + add.s64 %rd64, %rd20, %rd54; + mov.b32 %r62, 0f00000000; + mov.b64 %rd66, -8; + setp.lt.s32 %p10, %r61, %r13; + mov.b32 %r63, %r62; + mov.b32 %r64, %r62; + mov.b32 %r65, %r62; +$L__BB0_7: // =>This Inner Loop Header: Depth=1 + .loc 1 39 74 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:39:74 + // begin inline asm + mov.u64 %rd55, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd55, 1.0; + // end inline asm + mov.b32 %r27, 0; + // begin inline asm + mov.u32 %r25, %r27; + mov.u32 %r26, %r27; + @%p10 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r25, %r26 }, [ %rd64 + 0 ], %rd55; + // end inline asm + mov.b32 {%rs1, %rs2}, %r25; + mov.b32 {%rs3, %rs4}, %r26; + .loc 1 39 136 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:39:136 + cvt.f32.bf16 %r34, %rs1; + cvt.f32.bf16 %r35, %rs2; + cvt.f32.bf16 %r36, %rs3; + cvt.f32.bf16 %r37, %rs4; + .loc 1 40 106 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:106 + // begin inline asm + mov.u64 %rd58, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd58, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r29, %r27; + mov.u32 %r30, %r27; + @%p10 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r29, %r30 }, [ %rd65 + 0 ], %rd58; + // end inline asm + 
mov.b32 {%rs5, %rs6}, %r29; + mov.b32 {%rs7, %rs8}, %r30; + .loc 1 40 168 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:168 + cvt.f32.bf16 %r38, %rs5; + cvt.f32.bf16 %r39, %rs6; + cvt.f32.bf16 %r40, %rs7; + cvt.f32.bf16 %r41, %rs8; + .loc 1 43 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:43:23 + fma.rn.f32 %r42, %r34, %r38, %r62; + fma.rn.f32 %r43, %r35, %r39, %r63; + fma.rn.f32 %r44, %r36, %r40, %r64; + fma.rn.f32 %r45, %r37, %r41, %r65; + .loc 1 44 48 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:44:48 + selp.f32 %r62, %r42, %r62, %p10; + selp.f32 %r63, %r43, %r63, %p10; + selp.f32 %r64, %r44, %r64, %p10; + selp.f32 %r65, %r45, %r65, %p10; + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + add.s64 %rd66, %rd66, 8; + add.s64 %rd65, %rd65, 16; + add.s64 %rd64, %rd64, 16; + setp.lt.u64 %p12, %rd66, 120; + @%p12 bra $L__BB0_7; +// %bb.8: + .loc 1 23 44 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:44 + and.b32 %r47, %r2, 63; + .loc 1 23 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:23 + or.b32 %r48, %r1, %r47; + .loc 1 24 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:24:21 + setp.lt.s32 %p14, %r48, %r13; +$L__tmp3: + .loc 3 261 15 // standard.py:261:15 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + add.f32 %r49, %r62, %r63; + add.f32 %r50, %r64, %r49; + add.f32 %r51, %r65, %r50; + .loc 3 291 36 // standard.py:291:36 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + shfl.sync.bfly.b32 %r52, %r51, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + add.f32 %r53, %r51, %r52; +$L__tmp4: + .loc 1 45 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:28 + shl.b32 %r54, %r3, 1; + mov.b32 %r55, global_smem; + add.s32 %r56, %r55, %r54; + st.shared.b32 [%r56], %r53; + bar.sync 0; + shl.b32 %r57, %r47, 2; + add.s32 %r58, %r55, %r57; 
+ ld.shared.b32 %r46, [%r58]; + .loc 1 49 25 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:25 + mad.wide.s32 %rd61, %r48, 4, %rd22; + .loc 1 49 36 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:36 + and.b32 %r59, %r2, 64; + setp.eq.b32 %p15, %r59, 0; + and.pred %p13, %p15, %p14; + // begin inline asm + @%p13 st.global.b32 [ %rd61 + 0 ], { %r46 }; + // end inline asm + .loc 1 49 4 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:4 + ret; +$L__tmp5: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .file 3 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // 
DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 233 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xe2 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 120 +.b8 99 +.b8 52 +.b8 105 +.b8 110 +.b8 99 +.b8 119 +.b8 109 +.b8 107 +.b8 108 +.b8 50 +.b8 120 +.b8 111 +.b8 104 +.b8 51 +.b8 116 +.b8 54 +.b8 113 +.b8 104 +.b8 114 +.b8 105 +.b8 114 +.b8 104 +.b8 52 +.b8 104 +.b8 117 +.b8 98 +.b8 105 +.b8 118 +.b8 51 +.b8 113 +.b8 54 +.b8 52 +.b8 53 +.b8 118 +.b8 100 +.b8 113 +.b8 53 +.b8 115 +.b8 119 +.b8 98 +.b8 53 +.b8 103 +.b8 50 +.b8 100 +.b8 120 +.b8 101 +.b8 103 +.b8 122 +.b8 55 +.b8 98 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 120 +.b8 99 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1b DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa6:0x46 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 
$L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbb:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 30 // DW_AT_call_line +.b8 51 // DW_AT_call_column +.b8 4 // Abbrev [4] 0xd3:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp3 // DW_AT_low_pc +.b64 $L__tmp4 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.source new file mode 100644 index 0000000000000000000000000000000000000000..304029efee43c36005f1ac94ff2b23eeac4c9d27 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.source @@ -0,0 +1,325 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":69:0) +#loc68 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc70 = loc(unknown) +#loc73 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc77 = loc("in_ptr0"(#loc)) +#loc78 = loc("in_ptr1"(#loc)) +#loc79 = loc("out_ptr1"(#loc)) +#loc80 = loc("ks0"(#loc)) +#loc81 = loc("ks1"(#loc)) +#loc82 = loc("xnumel"(#loc)) +#loc83 = loc("r0_numel"(#loc)) +#loc135 = loc("a"(#loc56)) +#loc136 = loc("b"(#loc56)) +#loc142 = loc("input"(#loc68)) +#loc143 = loc("a"(#loc73)) +#loc144 
= loc("b"(#loc73)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %r0_numel_0 = arith.constant 128 : i32 loc(#loc84) + %xoffset = tt.get_program_id x : i32 loc(#loc85) + %xoffset_1 = arith.constant 64 : i32 loc(#loc86) + %xoffset_2 = arith.constant 64 : i32 loc(#loc86) + %xoffset_3 = arith.muli %xoffset, %xoffset_2 : i32 loc(#loc86) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc87) + %xindex_4 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc88) + %xindex_5 = tt.splat %xoffset_3 : i32 -> tensor<64x1xi32> loc(#loc89) + %xindex_6 = arith.addi %xindex_5, %xindex_4 : tensor<64x1xi32> loc(#loc89) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32> loc(#loc90) + %xmask_7 = arith.cmpi slt, %xindex_6, %xmask : tensor<64x1xi32> loc(#loc90) + %r0_base = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc91) + %r0_base_8 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> loc(#loc92) + %x0 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc93) + %x0_9 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc93) + %x0_10 = arith.remsi %x0, %x0_9 : tensor<64x1xi64> loc(#loc93) + %x1 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc94) + %x1_11 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc94) + %x1_12 = arith.divsi %x1, %x1_11 : tensor<64x1xi64> loc(#loc94) + %x1_13 = arith.constant 32 : i32 loc(#loc95) + %x1_14 = arith.constant 32 : i64 loc(#loc95) + %x1_15 = 
arith.constant dense<32> : tensor<64x1xi64> loc(#loc95) + %x1_16 = arith.remsi %x1_12, %x1_15 : tensor<64x1xi64> loc(#loc95) + %x2 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc96) + %x2_17 = tt.splat %ks1 : i64 -> tensor<64x1xi64> loc(#loc96) + %x2_18 = arith.divsi %x2, %x2_17 : tensor<64x1xi64> loc(#loc96) + %x5 = tt.call @torch._inductor.runtime.triton_helpers.div_floor_integer__i32S64_1S_i64__(%xindex_6, %ks0) : (tensor<64x1xi32>, i64) -> tensor<64x1xi64> loc(#loc97) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc98) + %_tmp4_19 = arith.constant dense<0.000000e+00> : tensor<64x8xf32> loc(#loc98) + %c0_i32 = arith.constant 0 : i32 loc(#loc16) + %c8_i32 = arith.constant 8 : i32 loc(#loc16) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc16) + %1 = arith.bitcast %r0_numel_0 : i32 to i32 loc(#loc16) + %2 = arith.bitcast %c8_i32 : i32 to i32 loc(#loc16) + %3 = ub.poison : i32 loc(#loc16) + %_tmp4_20 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp4_23 = %_tmp4_19) -> (tensor<64x8xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x8xi32> loc(#loc100) + %r0_index_24 = arith.addi %r0_index, %r0_base_8 : tensor<1x8xi32> loc(#loc100) + %r0_mask = arith.constant dense<128> : tensor<1x8xi32> loc(#loc101) + %r0_mask_25 = arith.cmpi slt, %r0_index_24, %r0_mask : tensor<1x8xi32> loc(#loc101) + %tmp0 = arith.constant 128 : i32 loc(#loc102) + %tmp0_26 = arith.constant 128 : i64 loc(#loc102) + %tmp0_27 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc102) + %tmp0_28 = arith.muli %tmp0_27, %x1_16 : tensor<64x1xi64> loc(#loc102) + %tmp0_29 = arith.extsi %r0_index_24 : tensor<1x8xi32> to tensor<1x8xi64> loc(#loc103) + %tmp0_30 = tt.broadcast %tmp0_29 : tensor<1x8xi64> -> tensor<64x8xi64> loc(#loc103) + %tmp0_31 = tt.broadcast %tmp0_28 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc103) + %tmp0_32 = arith.addi %tmp0_30, %tmp0_31 : tensor<64x8xi64> loc(#loc103) + %tmp0_33 = arith.constant 4096 : i32 loc(#loc104) + %tmp0_34 
= arith.constant 4096 : i64 loc(#loc104) + %tmp0_35 = arith.constant dense<4096> : tensor<64x1xi64> loc(#loc104) + %tmp0_36 = arith.muli %tmp0_35, %x0_10 : tensor<64x1xi64> loc(#loc104) + %tmp0_37 = tt.broadcast %tmp0_36 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc105) + %tmp0_38 = arith.addi %tmp0_32, %tmp0_37 : tensor<64x8xi64> loc(#loc105) + %tmp0_39 = arith.constant 4096 : i32 loc(#loc106) + %tmp0_40 = arith.constant 4096 : i64 loc(#loc106) + %tmp0_41 = arith.muli %tmp0_40, %ks0 : i64 loc(#loc106) + %tmp0_42 = tt.splat %tmp0_41 : i64 -> tensor<64x1xi64> loc(#loc107) + %tmp0_43 = arith.muli %tmp0_42, %x2_18 : tensor<64x1xi64> loc(#loc107) + %tmp0_44 = tt.broadcast %tmp0_43 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc108) + %tmp0_45 = arith.addi %tmp0_38, %tmp0_44 : tensor<64x8xi64> loc(#loc108) + %tmp0_46 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x8x!tt.ptr> loc(#loc109) + %tmp0_47 = tt.addptr %tmp0_46, %tmp0_45 : tensor<64x8x!tt.ptr>, tensor<64x8xi64> loc(#loc109) + %tmp0_48 = tt.broadcast %r0_mask_25 : tensor<1x8xi1> -> tensor<64x8xi1> loc(#loc110) + %tmp0_49 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x8xi1> loc(#loc110) + %tmp0_50 = arith.andi %tmp0_48, %tmp0_49 : tensor<64x8xi1> loc(#loc110) + %tmp0_51 = arith.constant 0.000000e+00 : f32 loc(#loc111) + %tmp0_52 = arith.constant dense<0.000000e+00> : tensor<64x8xf32> loc(#loc111) + %tmp0_53 = arith.truncf %tmp0_52 : tensor<64x8xf32> to tensor<64x8xbf16> loc(#loc111) + %tmp0_54 = tt.load %tmp0_47, %tmp0_50, %tmp0_53 evictionPolicy = evict_first : tensor<64x8x!tt.ptr> loc(#loc111) + %tmp0_55 = arith.extf %tmp0_54 : tensor<64x8xbf16> to tensor<64x8xf32> loc(#loc112) + %tmp1 = arith.constant 128 : i32 loc(#loc113) + %tmp1_56 = arith.constant 128 : i64 loc(#loc113) + %tmp1_57 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc113) + %tmp1_58 = arith.muli %tmp1_57, %x0_10 : tensor<64x1xi64> loc(#loc113) + %tmp1_59 = arith.extsi %r0_index_24 : tensor<1x8xi32> to tensor<1x8xi64> loc(#loc114) + 
%tmp1_60 = tt.broadcast %tmp1_59 : tensor<1x8xi64> -> tensor<64x8xi64> loc(#loc114) + %tmp1_61 = tt.broadcast %tmp1_58 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc114) + %tmp1_62 = arith.addi %tmp1_60, %tmp1_61 : tensor<64x8xi64> loc(#loc114) + %tmp1_63 = arith.constant 128 : i32 loc(#loc115) + %tmp1_64 = arith.constant 128 : i64 loc(#loc115) + %tmp1_65 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc115) + %tmp1_66 = arith.muli %tmp1_65, %x5 : tensor<64x1xi64> loc(#loc115) + %tmp1_67 = arith.constant 1 : i32 loc(#loc116) + %tmp1_68 = arith.extsi %tmp1_67 : i32 to i64 loc(#loc116) + %tmp1_69 = arith.cmpi sge, %tmp1_68, %ks0 : i64 loc(#loc116) + %tmp1_70 = arith.constant 1 : i32 loc(#loc117) + %tmp1_71 = arith.constant 1 : i32 loc(#loc117) + %tmp1_72 = arith.extui %tmp1_69 : i1 to i32 loc(#loc117) + %tmp1_73 = arith.muli %tmp1_71, %tmp1_72 : i32 loc(#loc117) + %tmp1_74 = arith.constant 1 : i32 loc(#loc118) + %tmp1_75 = arith.extsi %tmp1_74 : i32 to i64 loc(#loc118) + %tmp1_76 = arith.cmpi sgt, %ks0, %tmp1_75 : i64 loc(#loc118) + %tmp1_77 = arith.extui %tmp1_76 : i1 to i64 loc(#loc119) + %tmp1_78 = arith.muli %ks0, %tmp1_77 : i64 loc(#loc119) + %tmp1_79 = arith.extsi %tmp1_73 : i32 to i64 loc(#loc120) + %tmp1_80 = arith.addi %tmp1_79, %tmp1_78 : i64 loc(#loc120) + %tmp1_81 = tt.splat %tmp1_80 : i64 -> tensor<64x1xi64> loc(#loc121) + %tmp1_82 = arith.muli %tmp1_66, %tmp1_81 : tensor<64x1xi64> loc(#loc121) + %tmp1_83 = tt.broadcast %tmp1_82 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc122) + %tmp1_84 = arith.addi %tmp1_62, %tmp1_83 : tensor<64x8xi64> loc(#loc122) + %tmp1_85 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x8x!tt.ptr> loc(#loc123) + %tmp1_86 = tt.addptr %tmp1_85, %tmp1_84 : tensor<64x8x!tt.ptr>, tensor<64x8xi64> loc(#loc123) + %tmp1_87 = tt.broadcast %r0_mask_25 : tensor<1x8xi1> -> tensor<64x8xi1> loc(#loc124) + %tmp1_88 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x8xi1> loc(#loc124) + %tmp1_89 = arith.andi %tmp1_87, %tmp1_88 : 
tensor<64x8xi1> loc(#loc124) + %tmp1_90 = arith.constant 0.000000e+00 : f32 loc(#loc125) + %tmp1_91 = arith.constant dense<0.000000e+00> : tensor<64x8xf32> loc(#loc125) + %tmp1_92 = arith.truncf %tmp1_91 : tensor<64x8xf32> to tensor<64x8xbf16> loc(#loc125) + %tmp1_93 = tt.load %tmp1_86, %tmp1_89, %tmp1_92 evictionPolicy = evict_first : tensor<64x8x!tt.ptr> loc(#loc125) + %tmp1_94 = arith.extf %tmp1_93 : tensor<64x8xbf16> to tensor<64x8xf32> loc(#loc126) + %tmp2 = arith.mulf %tmp0_55, %tmp1_94 : tensor<64x8xf32> loc(#loc127) + %tmp5 = arith.addf %_tmp4_23, %tmp2 : tensor<64x8xf32> loc(#loc128) + %_tmp4_95 = tt.broadcast %r0_mask_25 : tensor<1x8xi1> -> tensor<64x8xi1> loc(#loc129) + %_tmp4_96 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x8xi1> loc(#loc129) + %_tmp4_97 = arith.andi %_tmp4_95, %_tmp4_96 : tensor<64x8xi1> loc(#loc129) + %_tmp4_98 = arith.select %_tmp4_97, %tmp5, %_tmp4_23 : tensor<64x8xi1>, tensor<64x8xf32> loc(#loc130) + scf.yield %_tmp4_98 : tensor<64x8xf32> loc(#loc48) + } loc(#loc99) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S64_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_20) : (tensor<64x8xf32>) -> tensor<64xf32> loc(#loc131) + %tmp4_21 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc132) + %tmp7 = arith.constant 0.000000e+00 : f32 loc(#loc133) + %tmp8 = arith.constant dense<0.000000e+00> : tensor<64x1xf32> loc(#loc134) + %tmp8_22 = arith.subf %tmp4_21, %tmp8 : tensor<64x1xf32> loc(#loc134) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc53) + %5 = tt.addptr %4, %xindex_6 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc53) + tt.store %5, %tmp8_22, %xmask_7 : tensor<64x1x!tt.ptr> loc(#loc54) + tt.return loc(#loc55) + } loc(#loc) + tt.func private @torch._inductor.runtime.triton_helpers.div_floor_integer__i32S64_1S_i64__(%a: tensor<64x1xi32> loc("a"(#loc56)), %b: i64 loc("b"(#loc56))) -> tensor<64x1xi64> attributes {noinline = false} { + %quot = 
arith.extsi %a : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc137) + %quot_0 = tt.splat %b : i64 -> tensor<64x1xi64> loc(#loc137) + %quot_1 = arith.divsi %quot, %quot_0 : tensor<64x1xi64> loc(#loc137) + %remainder = arith.extsi %a : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc138) + %remainder_2 = tt.splat %b : i64 -> tensor<64x1xi64> loc(#loc138) + %remainder_3 = arith.remsi %remainder, %remainder_2 : tensor<64x1xi64> loc(#loc138) + %fixed = arith.constant 0 : i32 loc(#loc139) + %fixed_4 = arith.extsi %fixed : i32 to i64 loc(#loc139) + %fixed_5 = tt.splat %fixed_4 : i64 -> tensor<64x1xi64> loc(#loc139) + %fixed_6 = arith.cmpi ne, %remainder_3, %fixed_5 : tensor<64x1xi64> loc(#loc139) + %fixed_7 = arith.constant 1 : i32 loc(#loc140) + %fixed_8 = arith.constant 1 : i64 loc(#loc140) + %fixed_9 = arith.constant dense<1> : tensor<64x1xi64> loc(#loc140) + %fixed_10 = arith.subi %quot_1, %fixed_9 : tensor<64x1xi64> loc(#loc140) + %fixed_11 = arith.select %fixed_6, %fixed_10, %quot_1 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc141) + %c0_i32 = arith.constant 0 : i32 loc(#loc62) + %cst = arith.constant dense<0> : tensor<64x1xi32> loc(#loc62) + %0 = arith.cmpi slt, %a, %cst : tensor<64x1xi32> loc(#loc62) + %c0_i32_12 = arith.constant 0 : i32 loc(#loc63) + %1 = arith.extsi %c0_i32_12 : i32 to i64 loc(#loc63) + %2 = arith.cmpi slt, %b, %1 : i64 loc(#loc63) + %3 = tt.splat %2 : i1 -> tensor<64x1xi1> loc(#loc64) + %4 = arith.cmpi ne, %0, %3 : tensor<64x1xi1> loc(#loc64) + %5 = arith.select %4, %fixed_11, %quot_1 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc65) + tt.return %5 : tensor<64x1xi64> loc(#loc66) + ^bb1: // no predecessors + %6 = ub.poison : tensor<64x1xi64> loc(#loc67) + tt.return %6 : tensor<64x1xi64> loc(#loc67) + } loc(#loc56) + tt.func private @"triton.language.standard.sum__fp32S64_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x8xf32> loc("input"(#loc68))) -> tensor<64xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) 
<{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc69) + tt.reduce.return %2 : f32 loc(#loc69) + }) : (tensor<64x8xf32>) -> tensor<64xf32> loc(#loc69) + tt.return %0 : tensor<64xf32> loc(#loc71) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64xf32> loc(#loc72) + tt.return %1 : tensor<64xf32> loc(#loc72) + } loc(#loc68) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc73)), %b: f32 loc("b"(#loc73))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc74) + tt.return %0 : f32 loc(#loc75) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc76) + tt.return %1 : f32 loc(#loc76) + } loc(#loc73) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":19:15) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:36) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:27) +#loc9 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":31:43) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc23 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc37 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:116) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:35) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc51 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":47:11) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":48:18) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":72:16) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":73:20) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc65 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:11) +#loc67 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:4) +#loc69 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc71 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc72 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc74 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc75 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc76 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc84 = loc("r0_numel"(#loc1)) +#loc85 = loc("xoffset"(#loc2)) +#loc86 = loc("xoffset"(#loc3)) +#loc87 = loc("xindex"(#loc4)) +#loc88 = loc("xindex"(#loc5)) +#loc89 = loc("xindex"(#loc6)) +#loc90 = loc("xmask"(#loc7)) +#loc91 = loc("r0_base"(#loc8)) +#loc92 = loc("r0_base"(#loc9)) +#loc93 = loc("x0"(#loc10)) +#loc94 = loc("x1"(#loc11)) +#loc95 = loc("x1"(#loc12)) +#loc96 = loc("x2"(#loc13)) +#loc97 = loc("x5"(#loc14)) +#loc98 = loc("_tmp4"(#loc15)) +#loc99 = loc("_tmp4"(#loc16)) +#loc100 = loc("r0_index"(#loc17)) +#loc101 = loc("r0_mask"(#loc18)) +#loc102 = loc("tmp0"(#loc19)) +#loc103 = loc("tmp0"(#loc20)) +#loc104 = loc("tmp0"(#loc21)) +#loc105 = loc("tmp0"(#loc22)) +#loc106 = loc("tmp0"(#loc23)) +#loc107 = loc("tmp0"(#loc24)) +#loc108 = loc("tmp0"(#loc25)) +#loc109 = loc("tmp0"(#loc26)) +#loc110 = loc("tmp0"(#loc27)) +#loc111 = loc("tmp0"(#loc28)) +#loc112 = loc("tmp0"(#loc29)) +#loc113 = loc("tmp1"(#loc30)) +#loc114 = loc("tmp1"(#loc31)) +#loc115 = loc("tmp1"(#loc32)) +#loc116 = loc("tmp1"(#loc33)) +#loc117 = loc("tmp1"(#loc34)) +#loc118 = loc("tmp1"(#loc35)) +#loc119 = loc("tmp1"(#loc36)) +#loc120 = loc("tmp1"(#loc37)) +#loc121 = loc("tmp1"(#loc38)) +#loc122 = loc("tmp1"(#loc39)) +#loc123 = loc("tmp1"(#loc40)) +#loc124 = loc("tmp1"(#loc41)) +#loc125 = loc("tmp1"(#loc42)) +#loc126 = loc("tmp1"(#loc43)) +#loc127 = loc("tmp2"(#loc44)) +#loc128 = loc("tmp5"(#loc45)) 
+#loc129 = loc("_tmp4"(#loc46)) +#loc130 = loc("_tmp4"(#loc47)) +#loc131 = loc("tmp4"(#loc49)) +#loc132 = loc("tmp4"(#loc50)) +#loc133 = loc("tmp7"(#loc51)) +#loc134 = loc("tmp8"(#loc52)) +#loc137 = loc("quot"(#loc57)) +#loc138 = loc("remainder"(#loc58)) +#loc139 = loc("fixed"(#loc59)) +#loc140 = loc("fixed"(#loc60)) +#loc141 = loc("fixed"(#loc61)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..11d9cea6fbbce0f04649951a0acbbd78d3ee6ebb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttgir @@ -0,0 +1,233 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc1 = loc(unknown) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc58 = loc("in_ptr0"(#loc)) +#loc59 = loc("in_ptr1"(#loc)) +#loc60 = loc("out_ptr1"(#loc)) +#loc61 = loc("ks0"(#loc)) +#loc62 = loc("ks1"(#loc)) +#loc63 = loc("xnumel"(#loc)) +#loc64 = loc("r0_numel"(#loc)) +#loc109 = loc("tmp4"(#loc52)) +#loc120 = loc(callsite(#loc1 at #loc109)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : 
i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<1x8xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<1> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_5 = arith.constant dense<32> : tensor<64x1xi64, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<64x8xbf16, #blocked> loc(#loc1) + %c0_i64 = arith.constant 0 : i64 loc(#loc1) + %c1_i64 = arith.constant 1 : i64 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c4096_i64 = arith.constant 4096 : i64 loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x8xf32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc65) + %xoffset_8 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc66) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc67) + %xindex_9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc67) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc67) + %xindex_11 = tt.expand_dims %xindex_9 {axis = 1 : i32} : tensor<64xi32, 
#ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> loc(#loc67) + %xindex_12 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked> loc(#loc68) + %xindex_13 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked1> loc(#loc68) + %xindex_14 = arith.addi %xindex_12, %xindex_10 : tensor<64x1xi32, #blocked> loc(#loc68) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<64x1xi32, #blocked1> loc(#loc68) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked> loc(#loc69) + %xmask_16 = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked1> loc(#loc69) + %xmask_17 = arith.cmpi slt, %xindex_14, %xmask : tensor<64x1xi32, #blocked> loc(#loc69) + %xmask_18 = arith.cmpi slt, %xindex_15, %xmask_16 : tensor<64x1xi32, #blocked1> loc(#loc69) + %r0_base = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc70) + %r0_base_19 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x8xi32, #blocked> loc(#loc70) + %x0 = arith.extsi %xindex_14 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> loc(#loc71) + %x0_20 = tt.splat %ks0 : i64 -> tensor<64x1xi64, #blocked> loc(#loc71) + %x0_21 = arith.remsi %x0, %x0_20 : tensor<64x1xi64, #blocked> loc(#loc71) + %x1 = arith.divsi %x0, %x0_20 : tensor<64x1xi64, #blocked> loc(#loc72) + %x1_22 = arith.remsi %x1, %cst_5 : tensor<64x1xi64, #blocked> loc(#loc73) + %x2 = tt.splat %ks1 : i64 -> tensor<64x1xi64, #blocked> loc(#loc74) + %x2_23 = arith.divsi %x0, %x2 : tensor<64x1xi64, #blocked> loc(#loc74) + %fixed = arith.cmpi ne, %x0_21, %cst_2 : tensor<64x1xi64, #blocked> loc(#loc111) + %fixed_24 = arith.subi %x1, %cst_4 : tensor<64x1xi64, #blocked> loc(#loc112) + %fixed_25 = arith.select %fixed, %fixed_24, %x1 : tensor<64x1xi1, #blocked>, tensor<64x1xi64, #blocked> loc(#loc113) + %x5 = arith.cmpi slt, %xindex_14, %cst_3 : tensor<64x1xi32, #blocked> loc(#loc114) + %x5_26 = arith.cmpi slt, 
%ks0, %c0_i64 : i64 loc(#loc115) + %x5_27 = tt.splat %x5_26 : i1 -> tensor<64x1xi1, #blocked> loc(#loc116) + %x5_28 = arith.cmpi ne, %x5, %x5_27 : tensor<64x1xi1, #blocked> loc(#loc116) + %x5_29 = arith.select %x5_28, %fixed_25, %x1 : tensor<64x1xi1, #blocked>, tensor<64x1xi64, #blocked> loc(#loc117) + %tmp0 = arith.muli %x1_22, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc79) + %tmp0_30 = tt.broadcast %tmp0 : tensor<64x1xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc80) + %tmp0_31 = arith.muli %x0_21, %cst_1 : tensor<64x1xi64, #blocked> loc(#loc81) + %tmp0_32 = tt.broadcast %tmp0_31 : tensor<64x1xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc82) + %tmp0_33 = arith.muli %ks0, %c4096_i64 : i64 loc(#loc83) + %tmp0_34 = tt.splat %tmp0_33 : i64 -> tensor<64x1xi64, #blocked> loc(#loc84) + %tmp0_35 = arith.muli %tmp0_34, %x2_23 : tensor<64x1xi64, #blocked> loc(#loc84) + %tmp0_36 = tt.broadcast %tmp0_35 : tensor<64x1xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc85) + %tmp0_37 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x8x!tt.ptr, #blocked> loc(#loc86) + %tmp0_38 = tt.broadcast %xmask_17 : tensor<64x1xi1, #blocked> -> tensor<64x8xi1, #blocked> loc(#loc87) + %tmp1 = arith.muli %x0_21, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc88) + %tmp1_39 = tt.broadcast %tmp1 : tensor<64x1xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc89) + %tmp1_40 = arith.muli %x5_29, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc90) + %tmp1_41 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc91) + %tmp1_42 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc92) + %tmp1_43 = arith.extui %tmp1_42 : i1 to i64 loc(#loc93) + %tmp1_44 = arith.muli %ks0, %tmp1_43 : i64 loc(#loc93) + %tmp1_45 = arith.extui %tmp1_41 : i1 to i64 loc(#loc118) + %tmp1_46 = arith.addi %tmp1_45, %tmp1_44 : i64 loc(#loc94) + %tmp1_47 = tt.splat %tmp1_46 : i64 -> tensor<64x1xi64, #blocked> loc(#loc96) + %tmp1_48 = arith.muli %tmp1_40, %tmp1_47 : tensor<64x1xi64, #blocked> loc(#loc96) + %tmp1_49 = tt.broadcast 
%tmp1_48 : tensor<64x1xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc97) + %tmp1_50 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x8x!tt.ptr, #blocked> loc(#loc98) + %_tmp4 = scf.for %_tmp4_53 = %c0_i32 to %c128_i32 step %c8_i32 iter_args(%arg8 = %cst_7) -> (tensor<64x8xf32, #blocked>) : i32 { + %r0_index = tt.splat %_tmp4_53 : i32 -> tensor<1x8xi32, #blocked> loc(#loc100) + %r0_index_54 = arith.addi %r0_index, %r0_base_19 : tensor<1x8xi32, #blocked> loc(#loc100) + %r0_mask = arith.cmpi slt, %r0_index_54, %cst : tensor<1x8xi32, #blocked> loc(#loc101) + %tmp0_55 = arith.extsi %r0_index_54 : tensor<1x8xi32, #blocked> to tensor<1x8xi64, #blocked> loc(#loc80) + %tmp0_56 = tt.broadcast %tmp0_55 : tensor<1x8xi64, #blocked> -> tensor<64x8xi64, #blocked> loc(#loc80) + %tmp0_57 = arith.addi %tmp0_56, %tmp0_30 : tensor<64x8xi64, #blocked> loc(#loc80) + %tmp0_58 = arith.addi %tmp0_57, %tmp0_32 : tensor<64x8xi64, #blocked> loc(#loc82) + %tmp0_59 = arith.addi %tmp0_58, %tmp0_36 : tensor<64x8xi64, #blocked> loc(#loc85) + %tmp0_60 = tt.addptr %tmp0_37, %tmp0_59 : tensor<64x8x!tt.ptr, #blocked>, tensor<64x8xi64, #blocked> loc(#loc86) + %tmp0_61 = tt.broadcast %r0_mask : tensor<1x8xi1, #blocked> -> tensor<64x8xi1, #blocked> loc(#loc87) + %tmp0_62 = arith.andi %tmp0_61, %tmp0_38 : tensor<64x8xi1, #blocked> loc(#loc87) + %tmp0_63 = tt.load %tmp0_60, %tmp0_62, %cst_6 evictionPolicy = evict_first : tensor<64x8x!tt.ptr, #blocked> loc(#loc102) + %tmp0_64 = arith.extf %tmp0_63 : tensor<64x8xbf16, #blocked> to tensor<64x8xf32, #blocked> loc(#loc103) + %tmp1_65 = arith.addi %tmp0_56, %tmp1_39 : tensor<64x8xi64, #blocked> loc(#loc89) + %tmp1_66 = arith.addi %tmp1_65, %tmp1_49 : tensor<64x8xi64, #blocked> loc(#loc97) + %tmp1_67 = tt.addptr %tmp1_50, %tmp1_66 : tensor<64x8x!tt.ptr, #blocked>, tensor<64x8xi64, #blocked> loc(#loc98) + %tmp1_68 = tt.load %tmp1_67, %tmp0_62, %cst_6 evictionPolicy = evict_first : tensor<64x8x!tt.ptr, #blocked> loc(#loc104) + %tmp1_69 = arith.extf %tmp1_68 : 
tensor<64x8xbf16, #blocked> to tensor<64x8xf32, #blocked> loc(#loc105) + %tmp2 = arith.mulf %tmp0_64, %tmp1_69 : tensor<64x8xf32, #blocked> loc(#loc106) + %tmp5 = arith.addf %arg8, %tmp2 : tensor<64x8xf32, #blocked> loc(#loc107) + %_tmp4_70 = arith.select %tmp0_62, %tmp5, %arg8 : tensor<64x8xi1, #blocked>, tensor<64x8xf32, #blocked> loc(#loc108) + scf.yield %_tmp4_70 : tensor<64x8xf32, #blocked> loc(#loc50) + } loc(#loc99) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_53: f32 loc(callsite(#loc1 at #loc109)), %tmp4_54: f32 loc(callsite(#loc1 at #loc109))): + %tmp4_55 = arith.addf %tmp4_53, %tmp4_54 : f32 loc(#loc121) + tt.reduce.return %tmp4_55 : f32 loc(#loc119) + }) : (tensor<64x8xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc119) + %tmp4_51 = ttg.convert_layout %tmp4 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc110) + %tmp4_52 = tt.expand_dims %tmp4_51 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> loc(#loc110) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> loc(#loc55) + %1 = tt.addptr %0, %xindex_15 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> loc(#loc55) + tt.store %1, %tmp4_52, %xmask_18 : tensor<64x1x!tt.ptr, #blocked1> loc(#loc56) + tt.return loc(#loc57) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) 
+#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) +#loc12 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc14 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc15 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc16 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc17 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc18 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc21 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc35 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc49 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc65 = loc("xoffset"(#loc2)) +#loc66 = loc("xoffset"(#loc3)) +#loc67 = loc("xindex"(#loc4)) +#loc68 = loc("xindex"(#loc5)) +#loc69 = loc("xmask"(#loc6)) +#loc70 = loc("r0_base"(#loc7)) +#loc71 = loc("x0"(#loc8)) +#loc72 = loc("x1"(#loc9)) +#loc73 = loc("x1"(#loc10)) +#loc74 = loc("x2"(#loc11)) +#loc75 = loc("fixed"(#loc12)) +#loc76 = loc("x5"(#loc13)) +#loc77 = loc("fixed"(#loc14)) +#loc78 = loc("fixed"(#loc15)) +#loc79 = loc("tmp0"(#loc20)) +#loc80 = loc("tmp0"(#loc21)) +#loc81 = loc("tmp0"(#loc22)) +#loc82 = loc("tmp0"(#loc23)) +#loc83 = loc("tmp0"(#loc24)) +#loc84 = loc("tmp0"(#loc25)) +#loc85 = loc("tmp0"(#loc26)) +#loc86 = loc("tmp0"(#loc27)) +#loc87 = loc("tmp0"(#loc28)) +#loc88 = loc("tmp1"(#loc29)) +#loc89 = loc("tmp1"(#loc30)) +#loc90 = loc("tmp1"(#loc31)) +#loc91 = loc("tmp1"(#loc32)) +#loc92 = loc("tmp1"(#loc33)) +#loc93 = loc("tmp1"(#loc34)) +#loc94 = loc("tmp1"(#loc35)) +#loc95 = loc("tmp1"(#loc36)) +#loc96 = loc("tmp1"(#loc37)) +#loc97 = 
loc("tmp1"(#loc38)) +#loc98 = loc("tmp1"(#loc39)) +#loc99 = loc("_tmp4"(#loc40)) +#loc100 = loc("r0_index"(#loc41)) +#loc101 = loc("r0_mask"(#loc42)) +#loc102 = loc("tmp0"(#loc43)) +#loc103 = loc("tmp0"(#loc44)) +#loc104 = loc("tmp1"(#loc45)) +#loc105 = loc("tmp1"(#loc46)) +#loc106 = loc("tmp2"(#loc47)) +#loc107 = loc("tmp5"(#loc48)) +#loc108 = loc("_tmp4"(#loc49)) +#loc110 = loc("tmp4"(#loc54)) +#loc111 = loc(callsite(#loc75 at #loc76)) +#loc112 = loc(callsite(#loc77 at #loc76)) +#loc113 = loc(callsite(#loc78 at #loc76)) +#loc114 = loc(callsite(#loc16 at #loc76)) +#loc115 = loc(callsite(#loc17 at #loc76)) +#loc116 = loc(callsite(#loc18 at #loc76)) +#loc117 = loc(callsite(#loc19 at #loc76)) +#loc118 = loc(fused[#loc94, #loc95]) +#loc119 = loc(callsite(#loc51 at #loc109)) +#loc121 = loc(callsite(#loc53 at #loc119)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..664f4aa9b8660f315ab341aaf3066c079ae87158 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/ADCZ7SME2JTK2CEPVRJ7J4XLDGQGUMK5VNKNGD5DTWMOCOTS4J2Q/triton_red_fused_zeros_0.ttir @@ -0,0 +1,228 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc6 = loc(unknown) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc60 = loc("in_ptr0"(#loc)) +#loc61 = loc("in_ptr1"(#loc)) +#loc62 = loc("out_ptr1"(#loc)) +#loc63 = loc("ks0"(#loc)) +#loc64 = loc("ks1"(#loc)) +#loc65 = loc("xnumel"(#loc)) +#loc66 = loc("r0_numel"(#loc)) +#loc113 = loc("tmp4"(#loc54)) +#loc124 = loc(callsite(#loc6 at #loc113)) +module { + tt.func public 
@triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %fixed = arith.constant dense<1> : tensor<64x1xi64> loc(#loc115) + %x5 = arith.constant dense<0> : tensor<64x1xi32> loc(#loc116) + %fixed_0 = arith.constant dense<0> : tensor<64x1xi64> loc(#loc117) + %x5_1 = arith.constant 0 : i64 loc(#loc118) + %c1_i64 = arith.constant 1 : i64 loc(#loc6) + %cst = arith.constant dense<0.000000e+00> : tensor<64x8xbf16> loc(#loc6) + %c8_i32 = arith.constant 8 : i32 loc(#loc7) + %c128_i32 = arith.constant 128 : i32 loc(#loc7) + %c0_i32 = arith.constant 0 : i32 loc(#loc7) + %cst_2 = arith.constant dense<4096> : tensor<64x1xi64> loc(#loc6) + %c4096_i64 = arith.constant 4096 : i64 loc(#loc6) + %cst_3 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc6) + %cst_4 = arith.constant dense<128> : tensor<1x8xi32> loc(#loc6) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x8xf32> loc(#loc6) + %x1 = arith.constant dense<32> : tensor<64x1xi64> loc(#loc70) + %c64_i32 = arith.constant 64 : i32 loc(#loc6) + %xoffset = tt.get_program_id x : i32 loc(#loc71) + %xoffset_6 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc72) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc73) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc74) + %xindex_8 = tt.splat %xoffset_6 : i32 -> tensor<64x1xi32> loc(#loc75) + %xindex_9 = arith.addi %xindex_8, %xindex_7 : tensor<64x1xi32> loc(#loc75) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32> loc(#loc76) + %xmask_10 = arith.cmpi slt, %xindex_9, 
%xmask : tensor<64x1xi32> loc(#loc76) + %r0_base = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc77) + %r0_base_11 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> loc(#loc78) + %x0 = arith.extsi %xindex_9 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc79) + %x0_12 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc79) + %x0_13 = arith.remsi %x0, %x0_12 : tensor<64x1xi64> loc(#loc79) + %x1_14 = arith.divsi %x0, %x0_12 : tensor<64x1xi64> loc(#loc80) + %x1_15 = arith.remsi %x1_14, %x1 : tensor<64x1xi64> loc(#loc70) + %x2 = tt.splat %ks1 : i64 -> tensor<64x1xi64> loc(#loc81) + %x2_16 = arith.divsi %x0, %x2 : tensor<64x1xi64> loc(#loc81) + %fixed_17 = arith.cmpi ne, %x0_13, %fixed_0 : tensor<64x1xi64> loc(#loc117) + %fixed_18 = arith.subi %x1_14, %fixed : tensor<64x1xi64> loc(#loc115) + %fixed_19 = arith.select %fixed_17, %fixed_18, %x1_14 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc119) + %x5_20 = arith.cmpi slt, %xindex_9, %x5 : tensor<64x1xi32> loc(#loc116) + %x5_21 = arith.cmpi slt, %ks0, %x5_1 : i64 loc(#loc118) + %x5_22 = tt.splat %x5_21 : i1 -> tensor<64x1xi1> loc(#loc120) + %x5_23 = arith.cmpi ne, %x5_20, %x5_22 : tensor<64x1xi1> loc(#loc120) + %x5_24 = arith.select %x5_23, %fixed_19, %x1_14 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc121) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c8_i32 iter_args(%_tmp4_26 = %cst_5) -> (tensor<64x8xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x8xi32> loc(#loc84) + %r0_index_27 = arith.addi %r0_index, %r0_base_11 : tensor<1x8xi32> loc(#loc84) + %r0_mask = arith.cmpi slt, %r0_index_27, %cst_4 : tensor<1x8xi32> loc(#loc85) + %tmp0 = arith.muli %x1_15, %cst_3 : tensor<64x1xi64> loc(#loc86) + %tmp0_28 = arith.extsi %r0_index_27 : tensor<1x8xi32> to tensor<1x8xi64> loc(#loc87) + %tmp0_29 = tt.broadcast %tmp0_28 : tensor<1x8xi64> -> tensor<64x8xi64> loc(#loc87) + %tmp0_30 = tt.broadcast %tmp0 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc87) + 
%tmp0_31 = arith.addi %tmp0_29, %tmp0_30 : tensor<64x8xi64> loc(#loc87) + %tmp0_32 = arith.muli %x0_13, %cst_2 : tensor<64x1xi64> loc(#loc88) + %tmp0_33 = tt.broadcast %tmp0_32 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc89) + %tmp0_34 = arith.addi %tmp0_31, %tmp0_33 : tensor<64x8xi64> loc(#loc89) + %tmp0_35 = arith.muli %ks0, %c4096_i64 : i64 loc(#loc90) + %tmp0_36 = tt.splat %tmp0_35 : i64 -> tensor<64x1xi64> loc(#loc91) + %tmp0_37 = arith.muli %tmp0_36, %x2_16 : tensor<64x1xi64> loc(#loc91) + %tmp0_38 = tt.broadcast %tmp0_37 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc92) + %tmp0_39 = arith.addi %tmp0_34, %tmp0_38 : tensor<64x8xi64> loc(#loc92) + %tmp0_40 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x8x!tt.ptr> loc(#loc93) + %tmp0_41 = tt.addptr %tmp0_40, %tmp0_39 : tensor<64x8x!tt.ptr>, tensor<64x8xi64> loc(#loc93) + %tmp0_42 = tt.broadcast %r0_mask : tensor<1x8xi1> -> tensor<64x8xi1> loc(#loc94) + %tmp0_43 = tt.broadcast %xmask_10 : tensor<64x1xi1> -> tensor<64x8xi1> loc(#loc94) + %tmp0_44 = arith.andi %tmp0_42, %tmp0_43 : tensor<64x8xi1> loc(#loc94) + %tmp0_45 = tt.load %tmp0_41, %tmp0_44, %cst evictionPolicy = evict_first : tensor<64x8x!tt.ptr> loc(#loc95) + %tmp0_46 = arith.extf %tmp0_45 : tensor<64x8xbf16> to tensor<64x8xf32> loc(#loc96) + %tmp1 = arith.muli %x0_13, %cst_3 : tensor<64x1xi64> loc(#loc97) + %tmp1_47 = tt.broadcast %tmp1 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc98) + %tmp1_48 = arith.addi %tmp0_29, %tmp1_47 : tensor<64x8xi64> loc(#loc98) + %tmp1_49 = arith.muli %x5_24, %cst_3 : tensor<64x1xi64> loc(#loc99) + %tmp1_50 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc100) + %tmp1_51 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc101) + %tmp1_52 = arith.extui %tmp1_51 : i1 to i64 loc(#loc102) + %tmp1_53 = arith.muli %ks0, %tmp1_52 : i64 loc(#loc102) + %tmp1_54 = arith.extui %tmp1_50 : i1 to i64 loc(#loc122) + %tmp1_55 = arith.addi %tmp1_54, %tmp1_53 : i64 loc(#loc103) + %tmp1_56 = tt.splat %tmp1_55 : i64 -> tensor<64x1xi64> loc(#loc105) + 
%tmp1_57 = arith.muli %tmp1_49, %tmp1_56 : tensor<64x1xi64> loc(#loc105) + %tmp1_58 = tt.broadcast %tmp1_57 : tensor<64x1xi64> -> tensor<64x8xi64> loc(#loc106) + %tmp1_59 = arith.addi %tmp1_48, %tmp1_58 : tensor<64x8xi64> loc(#loc106) + %tmp1_60 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x8x!tt.ptr> loc(#loc107) + %tmp1_61 = tt.addptr %tmp1_60, %tmp1_59 : tensor<64x8x!tt.ptr>, tensor<64x8xi64> loc(#loc107) + %tmp1_62 = tt.load %tmp1_61, %tmp0_44, %cst evictionPolicy = evict_first : tensor<64x8x!tt.ptr> loc(#loc108) + %tmp1_63 = arith.extf %tmp1_62 : tensor<64x8xbf16> to tensor<64x8xf32> loc(#loc109) + %tmp2 = arith.mulf %tmp0_46, %tmp1_63 : tensor<64x8xf32> loc(#loc110) + %tmp5 = arith.addf %_tmp4_26, %tmp2 : tensor<64x8xf32> loc(#loc111) + %_tmp4_64 = arith.select %tmp0_44, %tmp5, %_tmp4_26 : tensor<64x8xi1>, tensor<64x8xf32> loc(#loc112) + scf.yield %_tmp4_64 : tensor<64x8xf32> loc(#loc52) + } loc(#loc83) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_26: f32 loc(callsite(#loc6 at #loc113)), %tmp4_27: f32 loc(callsite(#loc6 at #loc113))): + %tmp4_28 = arith.addf %tmp4_26, %tmp4_27 : f32 loc(#loc125) + tt.reduce.return %tmp4_28 : f32 loc(#loc123) + }) : (tensor<64x8xf32>) -> tensor<64xf32> loc(#loc123) + %tmp4_25 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc114) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc57) + %1 = tt.addptr %0, %xindex_9 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc57) + tt.store %1, %tmp4_25, %xmask_10 : tensor<64x1x!tt.ptr> loc(#loc58) + tt.return loc(#loc59) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc3 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc4 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc5 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:36) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:27) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) 
+#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc35 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc49 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc67 = loc("fixed"(#loc1)) +#loc68 = loc("x5"(#loc2)) +#loc69 = loc("fixed"(#loc4)) +#loc70 = loc("x1"(#loc8)) +#loc71 = loc("xoffset"(#loc9)) +#loc72 = loc("xoffset"(#loc10)) +#loc73 = loc("xindex"(#loc11)) +#loc74 = loc("xindex"(#loc12)) +#loc75 = loc("xindex"(#loc13)) +#loc76 = loc("xmask"(#loc14)) +#loc77 = loc("r0_base"(#loc15)) +#loc78 = loc("r0_base"(#loc16)) +#loc79 = loc("x0"(#loc17)) +#loc80 = loc("x1"(#loc18)) +#loc81 = loc("x2"(#loc19)) +#loc82 = loc("fixed"(#loc20)) +#loc83 = loc("_tmp4"(#loc7)) +#loc84 = loc("r0_index"(#loc23)) +#loc85 = loc("r0_mask"(#loc24)) +#loc86 = loc("tmp0"(#loc25)) +#loc87 = loc("tmp0"(#loc26)) +#loc88 = loc("tmp0"(#loc27)) +#loc89 = loc("tmp0"(#loc28)) +#loc90 
= loc("tmp0"(#loc29)) +#loc91 = loc("tmp0"(#loc30)) +#loc92 = loc("tmp0"(#loc31)) +#loc93 = loc("tmp0"(#loc32)) +#loc94 = loc("tmp0"(#loc33)) +#loc95 = loc("tmp0"(#loc34)) +#loc96 = loc("tmp0"(#loc35)) +#loc97 = loc("tmp1"(#loc36)) +#loc98 = loc("tmp1"(#loc37)) +#loc99 = loc("tmp1"(#loc38)) +#loc100 = loc("tmp1"(#loc39)) +#loc101 = loc("tmp1"(#loc40)) +#loc102 = loc("tmp1"(#loc41)) +#loc103 = loc("tmp1"(#loc42)) +#loc104 = loc("tmp1"(#loc43)) +#loc105 = loc("tmp1"(#loc44)) +#loc106 = loc("tmp1"(#loc45)) +#loc107 = loc("tmp1"(#loc46)) +#loc108 = loc("tmp1"(#loc47)) +#loc109 = loc("tmp1"(#loc48)) +#loc110 = loc("tmp2"(#loc49)) +#loc111 = loc("tmp5"(#loc50)) +#loc112 = loc("_tmp4"(#loc51)) +#loc114 = loc("tmp4"(#loc56)) +#loc115 = loc(callsite(#loc67 at #loc68)) +#loc116 = loc(callsite(#loc3 at #loc68)) +#loc117 = loc(callsite(#loc69 at #loc68)) +#loc118 = loc(callsite(#loc5 at #loc68)) +#loc119 = loc(callsite(#loc82 at #loc68)) +#loc120 = loc(callsite(#loc21 at #loc68)) +#loc121 = loc(callsite(#loc22 at #loc68)) +#loc122 = loc(fused[#loc103, #loc104]) +#loc123 = loc(callsite(#loc53 at #loc113)) +#loc125 = loc(callsite(#loc55 at #loc123)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/__grp__triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/__grp__triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..dd12aec64a39e4ce2a816c6a0b5ff470291a8219 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/__grp__triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_zeros_0.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.source", "triton_red_fused_zeros_0.ttir": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttir", "triton_red_fused_zeros_0.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttgir", "triton_red_fused_zeros_0.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.llir", "triton_red_fused_zeros_0.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ptx", "triton_red_fused_zeros_0.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.cubin", "triton_red_fused_zeros_0.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..f7d8a8e80d241671315b84cce75d242b7e951520 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.json new file mode 100644 index 
0000000000000000000000000000000000000000..554762f9f2fd13b86f0e11f31ad25f1eb3e76d6b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"hash": "0c26ceae614c16280db457c0b301ae5b8479049d2cfeca61a0e3f0197645a333", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 256, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_zeros_0"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.llir new file mode 100644 index 0000000000000000000000000000000000000000..c1d3404989117a3d1a32d73a6626debb2315d125 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.llir @@ -0,0 +1,140 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = 
"LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_zeros_0(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, ptr addrspace(1) readnone captures(none) %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #0 !dbg !4 { + %8 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %9 = shl i32 %8, 6, !dbg !8 + %10 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %11 = and i32 %10, 252, !dbg !9 + %12 = lshr exact i32 %11, 2, !dbg !9 + %13 = or disjoint i32 %12, %9, !dbg !10 + %14 = and i32 %10, 3, !dbg !11 + %15 = sdiv i32 %13, 2048, !dbg !12 + %16 = mul i32 %15, 2048, !dbg !13 + %.decomposed = sub i32 %13, %16, !dbg !13 + %17 = srem i32 %15, 32, !dbg !14 + %18 = sdiv i32 %13, 65536, !dbg !15 + %19 = shl nsw i32 %17, 7, !dbg !16 + %20 = shl nsw i32 %.decomposed, 12, !dbg !17 + %21 = shl i32 %18, 23, !dbg !18 + %22 = shl i32 %13, 7, !dbg !19 + %23 = add i32 %21, %20 + %24 = add i32 %23, %19 + %25 = zext nneg i32 %14 to i64, !dbg !20 + %26 = sext i32 %22 to i64, !dbg !20 + %invariant.gep = getelementptr bfloat, ptr addrspace(1) %1, i64 %26, !dbg !20 + br label %27, !dbg !20 + +27: ; preds = %7, %27 + %indvars.iv = phi i64 [ 0, %7 ], [ %indvars.iv.next, %27 ] + %28 = phi float [ 0.000000e+00, %7 ], [ %43, %27 ] + %29 = or disjoint i64 %indvars.iv, %25, !dbg !21 + %30 = trunc nuw nsw i64 %29 to i32, !dbg !22 + %31 = add i32 %24, %30, !dbg !22 + %32 = sext i32 %31 to i64, !dbg !23 + %33 = getelementptr bfloat, ptr addrspace(1) %0, i64 %32, !dbg !23 + %34 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !24 + %35 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 
ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %33, i64 %34, i1 true) #4, !dbg !24 + %36 = bitcast i16 %35 to bfloat, !dbg !24 + %37 = fpext bfloat %36 to float, !dbg !25 + %gep = getelementptr bfloat, ptr addrspace(1) %invariant.gep, i64 %29, !dbg !26 + %38 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !27 + %39 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep, i64 %38, i1 true) #4, !dbg !27 + %40 = bitcast i16 %39 to bfloat, !dbg !27 + %41 = fpext bfloat %40 to float, !dbg !28 + %42 = fmul float %37, %41, !dbg !29 + %43 = fadd float %28, %42, !dbg !30 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 4, !dbg !20 + %44 = icmp samesign ult i64 %indvars.iv, 124, !dbg !20 + br i1 %44, label %27, label %45, !dbg !20 + +45: ; preds = %27 + %46 = and i32 %10, 63, !dbg !9 + %47 = or disjoint i32 %9, %46, !dbg !10 + %48 = bitcast float %43 to i32, !dbg !31 + %49 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %48, i32 2, i32 31), !dbg !31 + %50 = bitcast i32 %49 to float, !dbg !31 + %51 = fadd float %43, %50, !dbg !35 + %52 = bitcast float %51 to i32, !dbg !31 + %53 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %52, i32 1, i32 31), !dbg !31 + %54 = bitcast i32 %53 to float, !dbg !31 + %55 = fadd float %51, %54, !dbg !35 + %56 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11, !dbg !36 + store float %55, ptr addrspace(3) %56, align 4, !dbg !36 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !36 + %57 = shl nuw nsw i32 %46, 2, !dbg !36 + %58 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %57, !dbg !36 + %59 = load i32, ptr addrspace(3) %58, align 4, !dbg !36 + %60 = sext i32 %47 to i64, !dbg !37 + %61 = getelementptr float, ptr 
addrspace(1) %2, i64 %60, !dbg !37 + %62 = and i32 %10, 192, !dbg !38 + %63 = icmp eq i32 %62, 0, !dbg !38 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %59, ptr addrspace(1) %61, i1 %63) #4, !dbg !38 + ret void, !dbg !39 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_zeros_0", linkageName: "triton_red_fused_zeros_0", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 23, column: 28, 
scope: !4) +!8 = !DILocation(line: 23, column: 33, scope: !4) +!9 = !DILocation(line: 24, column: 44, scope: !4) +!10 = !DILocation(line: 24, column: 23, scope: !4) +!11 = !DILocation(line: 26, column: 37, scope: !4) +!12 = !DILocation(line: 29, column: 21, scope: !4) +!13 = !DILocation(line: 28, column: 19, scope: !4) +!14 = !DILocation(line: 29, column: 29, scope: !4) +!15 = !DILocation(line: 30, column: 19, scope: !4) +!16 = !DILocation(line: 39, column: 45, scope: !4) +!17 = !DILocation(line: 39, column: 55, scope: !4) +!18 = !DILocation(line: 39, column: 68, scope: !4) +!19 = !DILocation(line: 40, column: 45, scope: !4) +!20 = !DILocation(line: 33, column: 40, scope: !4) +!21 = !DILocation(line: 34, column: 31, scope: !4) +!22 = !DILocation(line: 39, column: 60, scope: !4) +!23 = !DILocation(line: 39, column: 34, scope: !4) +!24 = !DILocation(line: 39, column: 73, scope: !4) +!25 = !DILocation(line: 39, column: 127, scope: !4) +!26 = !DILocation(line: 40, column: 34, scope: !4) +!27 = !DILocation(line: 40, column: 50, scope: !4) +!28 = !DILocation(line: 40, column: 104, scope: !4) +!29 = !DILocation(line: 41, column: 22, scope: !4) +!30 = !DILocation(line: 43, column: 23, scope: !4) +!31 = !DILocation(line: 291, column: 36, scope: !32, inlinedAt: !34) +!32 = distinct !DILexicalBlockFile(scope: !4, file: !33, discriminator: 0) +!33 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!34 = !DILocation(line: 45, column: 25, scope: !4) +!35 = !DILocation(line: 261, column: 15, scope: !32, inlinedAt: !34) +!36 = !DILocation(line: 45, column: 28, scope: !4) +!37 = !DILocation(line: 49, column: 25, scope: !4) +!38 = !DILocation(line: 49, column: 36, scope: !4) +!39 = !DILocation(line: 49, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ptx 
b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..4dad29fe8fe8d0b5a46d10f33484a918e9f30276 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ptx @@ -0,0 +1,391 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_zeros_0 // -- Begin function triton_red_fused_zeros_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_zeros_0 +.visible .entry triton_red_fused_zeros_0( + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_2, + .param .u32 triton_red_fused_zeros_0_param_3, + .param .u32 triton_red_fused_zeros_0_param_4, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_6 +) +.reqntid 256 +{ + .reg .pred %p<5>; + .reg .b16 %rs<5>; + .reg .b32 %r<52>; + .reg .b64 %rd<25>; + .loc 1 18 0 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:18:0 + +// %bb.0: + ld.param.b64 %rd8, [triton_red_fused_zeros_0_param_2]; + ld.param.b64 %rd7, [triton_red_fused_zeros_0_param_0]; + ld.param.b64 %rd10, [triton_red_fused_zeros_0_param_1]; +$L__tmp0: + .loc 1 23 28 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:23:28 + mov.u32 %r7, %ctaid.x; + .loc 1 23 33 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:23:33 + shl.b32 %r1, %r7, 6; + .loc 1 24 44 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:44 + mov.u32 %r2, %tid.x; + and.b32 %r3, %r2, 252; + bfe.u32 %r8, %r2, 2, 6; + .loc 1 24 23 
// cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:23 + or.b32 %r9, %r8, %r1; + .loc 1 26 37 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:26:37 + and.b32 %r10, %r2, 3; + .loc 1 29 21 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:29:21 + bfe.s32 %r11, %r7, 25, 1; + shr.u32 %r12, %r11, 21; + add.s32 %r13, %r9, %r12; + shr.s32 %r14, %r13, 11; + .loc 1 29 29 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:29:29 + shr.u32 %r15, %r14, 27; + add.s32 %r16, %r14, %r15; + and.b32 %r17, %r16, 33554400; + sub.s32 %r18, %r14, %r17; + .loc 1 30 19 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:30:19 + shr.u32 %r19, %r11, 16; + add.s32 %r20, %r9, %r19; + .loc 1 39 45 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:45 + shl.b32 %r21, %r18, 7; + .loc 1 39 68 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:68 + shl.b32 %r22, %r20, 7; + and.b32 %r23, %r22, -8388608; + .loc 1 33 40 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:33:40 + cvt.u64.u32 %rd11, %r10; + shl.b32 %r24, %r7, 13; + shl.b32 %r25, %r8, 7; + or.b32 %r26, %r24, %r25; + cvt.s64.s32 %rd12, %r26; + or.b64 %rd13, %rd12, %rd11; + shl.b64 %rd14, %rd13, 1; + add.s64 %rd23, %rd10, %rd14; + shl.b32 %r27, %r7, 18; + add.s32 %r28, %r23, %r27; + shl.b32 %r29, %r8, 12; + or.b32 %r30, %r28, %r29; + add.s32 %r31, %r30, %r21; + or.b32 %r32, %r31, %r10; + shl.b32 %r33, %r14, 23; + sub.s32 %r34, %r32, %r33; + cvt.u64.u32 %rd2, %r34; + mov.b32 %r51, 0f00000000; + mov.b64 %rd24, -4; +$L__BB0_1: // =>This Inner Loop Header: Depth=1 + .loc 1 39 34 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:34 + add.s64 %rd21, %rd2, %rd24; + cvt.u32.u64 %r35, %rd21; + add.s32 %r36, %r35, 4; + mad.wide.s32 %rd16, %r36, 2, %rd7; + .loc 1 39 73 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:73 + // begin inline asm + mov.u64 %rd15, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd15, 1.0; + // end inline asm + mov.b16 
%rs2, 0; + mov.pred %p1, -1; + // begin inline asm + mov.u16 %rs1, %rs2; + @%p1 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs1 }, [ %rd16 + 0 ], %rd15; + // end inline asm + .loc 1 39 127 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:127 + cvt.f32.bf16 %r37, %rs1; + .loc 1 40 50 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:50 + // begin inline asm + mov.u64 %rd18, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd18, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs3, %rs2; + @%p1 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs3 }, [ %rd23 + 0 ], %rd18; + // end inline asm + .loc 1 40 104 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:104 + cvt.f32.bf16 %r38, %rs3; + .loc 1 43 23 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:43:23 + fma.rn.f32 %r51, %r37, %r38, %r51; + .loc 1 33 40 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:33:40 + add.s64 %rd24, %rd24, 4; + add.s64 %rd23, %rd23, 8; + setp.lt.u64 %p3, %rd24, 124; + @%p3 bra $L__BB0_1; +// %bb.2: + .loc 1 24 44 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:44 + and.b32 %r40, %r2, 63; + .loc 1 24 23 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:23 + or.b32 %r41, %r1, %r40; +$L__tmp1: + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r42, %r51, 2, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r43, %r51, %r42; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r44, %r43, 1, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r45, %r43, %r44; +$L__tmp2: + .loc 1 45 28 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:28 + mov.b32 %r46, global_smem; + add.s32 %r47, %r46, %r3; + 
st.shared.b32 [%r47], %r45; + bar.sync 0; + shl.b32 %r48, %r40, 2; + add.s32 %r49, %r46, %r48; + ld.shared.b32 %r39, [%r49]; + .loc 1 49 25 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:25 + mad.wide.s32 %rd22, %r41, 4, %rd8; + .loc 1 49 36 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:36 + and.b32 %r50, %r2, 192; + setp.eq.b32 %p4, %r50, 0; + // begin inline asm + @%p4 st.global.b32 [ %rd22 + 0 ], { %r39 }; + // end inline asm + .loc 1 49 4 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // 
DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 209 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xca DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 122 +.b8 52 +.b8 105 +.b8 53 +.b8 112 +.b8 107 +.b8 117 +.b8 99 +.b8 117 +.b8 51 +.b8 108 +.b8 109 +.b8 120 +.b8 54 +.b8 113 +.b8 102 +.b8 100 +.b8 113 +.b8 97 +.b8 102 +.b8 106 +.b8 50 +.b8 122 +.b8 103 +.b8 103 +.b8 117 +.b8 110 +.b8 52 +.b8 121 +.b8 106 +.b8 114 +.b8 116 +.b8 103 +.b8 55 +.b8 97 +.b8 120 +.b8 50 +.b8 117 +.b8 102 +.b8 113 +.b8 116 +.b8 104 +.b8 107 +.b8 101 +.b8 52 +.b8 122 +.b8 97 +.b8 51 +.b8 102 +.b8 102 +.b8 55 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 122 +.b8 52 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1b DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa6:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // 
DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbb:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.source new file mode 100644 index 0000000000000000000000000000000000000000..97d359a4c3c0416f4cb890b48fc1f3b025de4e7d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.source @@ -0,0 +1,222 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc46 = loc(unknown) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc53 = loc("in_ptr0"(#loc)) +#loc54 = loc("in_ptr1"(#loc)) +#loc55 = loc("out_ptr1"(#loc)) +#loc56 = loc("xnumel"(#loc)) +#loc57 = loc("r0_numel"(#loc)) +#loc97 = loc("input"(#loc44)) +#loc98 = loc("a"(#loc49)) +#loc99 = loc("b"(#loc49)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 524288 : i32 loc(#loc58) + 
%r0_numel_1 = arith.constant 128 : i32 loc(#loc59) + %xoffset = tt.get_program_id x : i32 loc(#loc60) + %xoffset_2 = arith.constant 64 : i32 loc(#loc61) + %xoffset_3 = arith.constant 64 : i32 loc(#loc61) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc61) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc62) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc63) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<64x1xi32> loc(#loc64) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<64x1xi32> loc(#loc64) + %xmask = arith.constant true loc(#loc65) + %xmask_8 = arith.constant dense : tensor<64x4xi1> loc(#loc65) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc66) + %r0_base_9 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc67) + %x0 = arith.constant 2048 : i32 loc(#loc68) + %x0_10 = arith.constant 2048 : i32 loc(#loc68) + %x0_11 = arith.constant dense<2048> : tensor<64x1xi32> loc(#loc68) + %x0_12 = arith.remsi %xindex_7, %x0_11 : tensor<64x1xi32> loc(#loc68) + %x1 = arith.constant 2048 : i32 loc(#loc69) + %x1_13 = arith.constant 2048 : i32 loc(#loc69) + %x1_14 = arith.constant dense<2048> : tensor<64x1xi32> loc(#loc69) + %x1_15 = arith.divsi %xindex_7, %x1_14 : tensor<64x1xi32> loc(#loc69) + %x1_16 = arith.constant 32 : i32 loc(#loc70) + %x1_17 = arith.constant 32 : i32 loc(#loc70) + %x1_18 = arith.constant dense<32> : tensor<64x1xi32> loc(#loc70) + %x1_19 = arith.remsi %x1_15, %x1_18 : tensor<64x1xi32> loc(#loc70) + %x2 = arith.constant 65536 : i32 loc(#loc71) + %x2_20 = arith.constant 65536 : i32 loc(#loc71) + %x2_21 = arith.constant dense<65536> : tensor<64x1xi32> loc(#loc71) + %x2_22 = arith.divsi %xindex_7, %x2_21 : tensor<64x1xi32> loc(#loc71) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc72) + %_tmp4_23 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc72) + %c0_i32 = 
arith.constant 0 : i32 loc(#loc16) + %c4_i32 = arith.constant 4 : i32 loc(#loc16) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc16) + %1 = arith.bitcast %r0_numel_1 : i32 to i32 loc(#loc16) + %2 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc16) + %3 = ub.poison : i32 loc(#loc16) + %_tmp4_24 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp4_27 = %_tmp4_23) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc74) + %r0_index_28 = arith.addi %r0_index, %r0_base_9 : tensor<1x4xi32> loc(#loc74) + %r0_mask = arith.constant dense<128> : tensor<1x4xi32> loc(#loc75) + %r0_mask_29 = arith.cmpi slt, %r0_index_28, %r0_mask : tensor<1x4xi32> loc(#loc75) + %tmp0 = arith.constant 128 : i32 loc(#loc76) + %tmp0_30 = arith.constant 128 : i32 loc(#loc76) + %tmp0_31 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc76) + %tmp0_32 = arith.muli %tmp0_31, %x1_19 : tensor<64x1xi32> loc(#loc76) + %tmp0_33 = tt.broadcast %r0_index_28 : tensor<1x4xi32> -> tensor<64x4xi32> loc(#loc77) + %tmp0_34 = tt.broadcast %tmp0_32 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc77) + %tmp0_35 = arith.addi %tmp0_33, %tmp0_34 : tensor<64x4xi32> loc(#loc77) + %tmp0_36 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_37 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_38 = arith.constant dense<4096> : tensor<64x1xi32> loc(#loc78) + %tmp0_39 = arith.muli %tmp0_38, %x0_12 : tensor<64x1xi32> loc(#loc78) + %tmp0_40 = tt.broadcast %tmp0_39 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc79) + %tmp0_41 = arith.addi %tmp0_35, %tmp0_40 : tensor<64x4xi32> loc(#loc79) + %tmp0_42 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_43 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_44 = arith.constant dense<8388608> : tensor<64x1xi32> loc(#loc80) + %tmp0_45 = arith.muli %tmp0_44, %x2_22 : tensor<64x1xi32> loc(#loc80) + %tmp0_46 = tt.broadcast %tmp0_45 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc81) + %tmp0_47 = arith.addi %tmp0_41, %tmp0_46 : tensor<64x4xi32> 
loc(#loc81) + %tmp0_48 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc82) + %tmp0_49 = tt.addptr %tmp0_48, %tmp0_47 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc82) + %tmp0_50 = arith.constant 0.000000e+00 : f32 loc(#loc83) + %tmp0_51 = tt.broadcast %r0_mask_29 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc83) + %tmp0_52 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc83) + %tmp0_53 = arith.truncf %tmp0_52 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc83) + %tmp0_54 = tt.load %tmp0_49, %tmp0_51, %tmp0_53 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc83) + %tmp0_55 = arith.extf %tmp0_54 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc84) + %tmp1 = arith.constant 128 : i32 loc(#loc85) + %tmp1_56 = arith.constant 128 : i32 loc(#loc85) + %tmp1_57 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc85) + %tmp1_58 = arith.muli %tmp1_57, %xindex_7 : tensor<64x1xi32> loc(#loc85) + %tmp1_59 = tt.broadcast %r0_index_28 : tensor<1x4xi32> -> tensor<64x4xi32> loc(#loc86) + %tmp1_60 = tt.broadcast %tmp1_58 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc86) + %tmp1_61 = arith.addi %tmp1_59, %tmp1_60 : tensor<64x4xi32> loc(#loc86) + %tmp1_62 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc87) + %tmp1_63 = tt.addptr %tmp1_62, %tmp1_61 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc87) + %tmp1_64 = arith.constant 0.000000e+00 : f32 loc(#loc88) + %tmp1_65 = tt.broadcast %r0_mask_29 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc88) + %tmp1_66 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc88) + %tmp1_67 = arith.truncf %tmp1_66 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc88) + %tmp1_68 = tt.load %tmp1_63, %tmp1_65, %tmp1_67 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc88) + %tmp1_69 = arith.extf %tmp1_68 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc89) + %tmp2 = arith.mulf %tmp0_55, %tmp1_69 : tensor<64x4xf32> loc(#loc90) + %tmp5 = arith.addf %_tmp4_27, %tmp2 : tensor<64x4xf32> 
loc(#loc91) + %_tmp4_70 = tt.broadcast %r0_mask_29 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc92) + %_tmp4_71 = arith.select %_tmp4_70, %tmp5, %_tmp4_27 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc92) + scf.yield %_tmp4_71 : tensor<64x4xf32> loc(#loc36) + } loc(#loc73) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_24) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc93) + %tmp4_25 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc94) + %tmp7 = arith.constant 0.000000e+00 : f32 loc(#loc95) + %tmp8 = arith.constant dense<0.000000e+00> : tensor<64x1xf32> loc(#loc96) + %tmp8_26 = arith.subf %tmp4_25, %tmp8 : tensor<64x1xf32> loc(#loc96) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc41) + %5 = tt.addptr %4, %xindex_7 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc41) + tt.store %5, %tmp8_26 : tensor<64x1x!tt.ptr> loc(#loc42) + tt.return loc(#loc43) + } loc(#loc) + tt.func private @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x4xf32> loc("input"(#loc44))) -> tensor<64xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc45) + tt.reduce.return %2 : f32 loc(#loc45) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc45) + tt.return %0 : tensor<64xf32> loc(#loc47) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64xf32> loc(#loc48) + tt.return %1 : tensor<64xf32> loc(#loc48) + } loc(#loc44) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc49)), %b: f32 loc("b"(#loc49))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc50) + tt.return %0 : f32 loc(#loc51) + ^bb1: // no predecessors + %1 = ub.poison : f32 
loc(#loc52) + tt.return %1 : f32 loc(#loc52) + } loc(#loc49) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":20:15) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":25:46) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:27) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc15 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":32:43) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":33:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":34:31) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc29 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:8) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":47:11) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":48:18) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc43 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc58 = loc("xnumel"(#loc1)) +#loc59 = loc("r0_numel"(#loc2)) +#loc60 = loc("xoffset"(#loc3)) +#loc61 = loc("xoffset"(#loc4)) +#loc62 = loc("xindex"(#loc5)) +#loc63 = loc("xindex"(#loc6)) +#loc64 = loc("xindex"(#loc7)) +#loc65 = loc("xmask"(#loc8)) +#loc66 = loc("r0_base"(#loc9)) +#loc67 = loc("r0_base"(#loc10)) +#loc68 = loc("x0"(#loc11)) +#loc69 = loc("x1"(#loc12)) +#loc70 = loc("x1"(#loc13)) +#loc71 = loc("x2"(#loc14)) +#loc72 = loc("_tmp4"(#loc15)) +#loc73 = loc("_tmp4"(#loc16)) +#loc74 = loc("r0_index"(#loc17)) +#loc75 = loc("r0_mask"(#loc18)) +#loc76 = loc("tmp0"(#loc19)) +#loc77 = loc("tmp0"(#loc20)) +#loc78 = loc("tmp0"(#loc21)) +#loc79 = loc("tmp0"(#loc22)) +#loc80 = loc("tmp0"(#loc23)) +#loc81 = loc("tmp0"(#loc24)) +#loc82 = loc("tmp0"(#loc25)) +#loc83 = loc("tmp0"(#loc26)) +#loc84 = loc("tmp0"(#loc27)) +#loc85 = loc("tmp1"(#loc28)) +#loc86 = loc("tmp1"(#loc29)) +#loc87 = loc("tmp1"(#loc30)) +#loc88 = loc("tmp1"(#loc31)) +#loc89 = loc("tmp1"(#loc32)) +#loc90 = loc("tmp2"(#loc33)) +#loc91 = loc("tmp5"(#loc34)) +#loc92 = loc("_tmp4"(#loc35)) +#loc93 = loc("tmp4"(#loc37)) +#loc94 = loc("tmp4"(#loc38)) +#loc95 = loc("tmp7"(#loc39)) +#loc96 = loc("tmp8"(#loc40)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..265e1e5508c6619256a2bfb15cc952c5b96d4457 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttgir @@ -0,0 +1,155 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc1 = loc(unknown) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc39 = loc("in_ptr0"(#loc)) +#loc40 = loc("in_ptr1"(#loc)) +#loc41 = loc("out_ptr1"(#loc)) +#loc42 = loc("xnumel"(#loc)) +#loc43 = loc("r0_numel"(#loc)) +#loc73 = loc("tmp4"(#loc33)) +#loc76 = loc(callsite(#loc1 at #loc73)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<2048> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<32> : tensor<64x1xi32, #blocked> loc(#loc1) + 
%cst_1 = arith.constant dense<65536> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<4096> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<8388608> : tensor<64x1xi32, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x4xbf16, #blocked> loc(#loc1) + %cst_6 = arith.constant dense<128> : tensor<1x4xi32, #blocked> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x4xf32, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc44) + %xoffset_8 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc45) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc46) + %xindex_9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc46) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc46) + %xindex_11 = tt.expand_dims %xindex_9 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> loc(#loc46) + %xindex_12 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked> loc(#loc47) + %xindex_13 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked1> loc(#loc47) + %xindex_14 = arith.addi %xindex_12, %xindex_10 : tensor<64x1xi32, #blocked> loc(#loc47) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<64x1xi32, #blocked1> loc(#loc47) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc48) + %r0_base_16 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x4xi32, #blocked> loc(#loc48) + %x0 = arith.remsi %xindex_14, %cst : tensor<64x1xi32, #blocked> loc(#loc49) + %x1 = arith.divsi %xindex_14, %cst : tensor<64x1xi32, #blocked> loc(#loc50) + %x1_17 = arith.remsi %x1, %cst_0 : tensor<64x1xi32, #blocked> loc(#loc51) + %x2 = arith.divsi %xindex_14, %cst_1 : tensor<64x1xi32, #blocked> loc(#loc52) + %tmp0 = arith.muli %x1_17, %cst_2 : tensor<64x1xi32, #blocked> loc(#loc53) + %tmp0_18 = tt.broadcast %tmp0 : tensor<64x1xi32, #blocked> -> tensor<64x4xi32, #blocked> loc(#loc54) + %tmp0_19 = arith.muli %x0, %cst_3 : tensor<64x1xi32, #blocked> loc(#loc55) + %tmp0_20 = tt.broadcast %tmp0_19 : tensor<64x1xi32, #blocked> -> tensor<64x4xi32, #blocked> loc(#loc56) + %tmp0_21 = arith.muli %x2, %cst_4 : tensor<64x1xi32, #blocked> loc(#loc57) + %tmp0_22 = tt.broadcast %tmp0_21 : tensor<64x1xi32, #blocked> -> tensor<64x4xi32, #blocked> loc(#loc58) + %tmp0_23 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked> loc(#loc59) + %tmp1 = arith.muli %xindex_14, %cst_2 : tensor<64x1xi32, #blocked> loc(#loc60) + %tmp1_24 = tt.broadcast %tmp1 : tensor<64x1xi32, #blocked> -> tensor<64x4xi32, #blocked> loc(#loc61) + %tmp1_25 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked> loc(#loc62) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_28 = %cst_7) -> (tensor<64x4xf32, #blocked>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32, #blocked> loc(#loc64) + %r0_index_29 = arith.addi %r0_index, %r0_base_16 : tensor<1x4xi32, #blocked> loc(#loc64) + %r0_mask = arith.cmpi slt, %r0_index_29, %cst_6 : tensor<1x4xi32, #blocked> loc(#loc65) + %tmp0_30 = tt.broadcast %r0_index_29 : tensor<1x4xi32, #blocked> -> tensor<64x4xi32, #blocked> loc(#loc54) + %tmp0_31 = arith.addi %tmp0_30, %tmp0_18 : tensor<64x4xi32, #blocked> loc(#loc54) + %tmp0_32 = arith.addi %tmp0_31, %tmp0_20 : tensor<64x4xi32, #blocked> loc(#loc56) + %tmp0_33 = arith.addi 
%tmp0_32, %tmp0_22 : tensor<64x4xi32, #blocked> loc(#loc58) + %tmp0_34 = tt.addptr %tmp0_23, %tmp0_33 : tensor<64x4x!tt.ptr, #blocked>, tensor<64x4xi32, #blocked> loc(#loc59) + %tmp0_35 = tt.broadcast %r0_mask : tensor<1x4xi1, #blocked> -> tensor<64x4xi1, #blocked> loc(#loc66) + %tmp0_36 = tt.load %tmp0_34, %tmp0_35, %cst_5 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked> loc(#loc66) + %tmp0_37 = arith.extf %tmp0_36 : tensor<64x4xbf16, #blocked> to tensor<64x4xf32, #blocked> loc(#loc67) + %tmp1_38 = arith.addi %tmp0_30, %tmp1_24 : tensor<64x4xi32, #blocked> loc(#loc61) + %tmp1_39 = tt.addptr %tmp1_25, %tmp1_38 : tensor<64x4x!tt.ptr, #blocked>, tensor<64x4xi32, #blocked> loc(#loc62) + %tmp1_40 = tt.load %tmp1_39, %tmp0_35, %cst_5 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked> loc(#loc68) + %tmp1_41 = arith.extf %tmp1_40 : tensor<64x4xbf16, #blocked> to tensor<64x4xf32, #blocked> loc(#loc69) + %tmp2 = arith.mulf %tmp0_37, %tmp1_41 : tensor<64x4xf32, #blocked> loc(#loc70) + %tmp5 = arith.addf %_tmp4_28, %tmp2 : tensor<64x4xf32, #blocked> loc(#loc71) + %_tmp4_42 = arith.select %tmp0_35, %tmp5, %_tmp4_28 : tensor<64x4xi1, #blocked>, tensor<64x4xf32, #blocked> loc(#loc72) + scf.yield %_tmp4_42 : tensor<64x4xf32, #blocked> loc(#loc31) + } loc(#loc63) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_28: f32 loc(callsite(#loc1 at #loc73)), %tmp4_29: f32 loc(callsite(#loc1 at #loc73))): + %tmp4_30 = arith.addf %tmp4_28, %tmp4_29 : f32 loc(#loc77) + tt.reduce.return %tmp4_30 : f32 loc(#loc75) + }) : (tensor<64x4xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc75) + %tmp4_26 = ttg.convert_layout %tmp4 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc74) + %tmp4_27 = tt.expand_dims %tmp4_26 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> loc(#loc74) + %0 = 
tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> loc(#loc36) + %1 = tt.addptr %0, %xindex_15 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> loc(#loc36) + tt.store %1, %tmp4_27 : tensor<64x1x!tt.ptr, #blocked1> loc(#loc37) + tt.return loc(#loc38) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc14 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":33:40) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":34:31) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc28 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:8) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc44 = loc("xoffset"(#loc2)) +#loc45 = loc("xoffset"(#loc3)) +#loc46 = loc("xindex"(#loc4)) +#loc47 = loc("xindex"(#loc5)) +#loc48 = loc("r0_base"(#loc6)) +#loc49 = loc("x0"(#loc7)) +#loc50 = loc("x1"(#loc8)) +#loc51 = loc("x1"(#loc9)) +#loc52 = loc("x2"(#loc10)) +#loc53 = loc("tmp0"(#loc11)) +#loc54 = loc("tmp0"(#loc12)) +#loc55 = loc("tmp0"(#loc13)) +#loc56 = loc("tmp0"(#loc14)) +#loc57 = loc("tmp0"(#loc15)) +#loc58 = loc("tmp0"(#loc16)) +#loc59 = loc("tmp0"(#loc17)) +#loc60 = loc("tmp1"(#loc18)) +#loc61 = loc("tmp1"(#loc19)) +#loc62 = loc("tmp1"(#loc20)) +#loc63 = loc("_tmp4"(#loc21)) +#loc64 = loc("r0_index"(#loc22)) +#loc65 = loc("r0_mask"(#loc23)) +#loc66 = loc("tmp0"(#loc24)) +#loc67 = 
loc("tmp0"(#loc25)) +#loc68 = loc("tmp1"(#loc26)) +#loc69 = loc("tmp1"(#loc27)) +#loc70 = loc("tmp2"(#loc28)) +#loc71 = loc("tmp5"(#loc29)) +#loc72 = loc("_tmp4"(#loc30)) +#loc74 = loc("tmp4"(#loc35)) +#loc75 = loc(callsite(#loc32 at #loc73)) +#loc77 = loc(callsite(#loc34 at #loc75)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..6557dcac8d5f7969d66b0b4bae2a096f366aa7df --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/BQTM5LTBJQLCQDNUK7ALGANOLOCHSBE5FT7MUYNA4PYBS5SFUMZQ/triton_red_fused_zeros_0.ttir @@ -0,0 +1,152 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc1 = loc(unknown) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc41 = loc("in_ptr0"(#loc)) +#loc42 = loc("in_ptr1"(#loc)) +#loc43 = loc("out_ptr1"(#loc)) +#loc44 = loc("xnumel"(#loc)) +#loc45 = loc("r0_numel"(#loc)) +#loc77 = loc("tmp4"(#loc35)) +#loc80 = loc(callsite(#loc1 at #loc77)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x4xbf16> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc2) + %c128_i32 = arith.constant 128 : i32 loc(#loc2) + %c0_i32 = arith.constant 0 : i32 loc(#loc2) + 
%cst_0 = arith.constant dense<8388608> : tensor<64x1xi32> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<64x1xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<128> : tensor<1x4xi32> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc1) + %x2 = arith.constant dense<65536> : tensor<64x1xi32> loc(#loc46) + %x1 = arith.constant dense<32> : tensor<64x1xi32> loc(#loc47) + %cst_5 = arith.constant dense<2048> : tensor<64x1xi32> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc48) + %xoffset_6 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc49) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc50) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc51) + %xindex_8 = tt.splat %xoffset_6 : i32 -> tensor<64x1xi32> loc(#loc52) + %xindex_9 = arith.addi %xindex_8, %xindex_7 : tensor<64x1xi32> loc(#loc52) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc53) + %r0_base_10 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc54) + %x0 = arith.remsi %xindex_9, %cst_5 : tensor<64x1xi32> loc(#loc55) + %x1_11 = arith.divsi %xindex_9, %cst_5 : tensor<64x1xi32> loc(#loc56) + %x1_12 = arith.remsi %x1_11, %x1 : tensor<64x1xi32> loc(#loc47) + %x2_13 = arith.divsi %xindex_9, %x2 : tensor<64x1xi32> loc(#loc46) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_15 = %cst_4) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc58) + %r0_index_16 = arith.addi %r0_index, %r0_base_10 : tensor<1x4xi32> loc(#loc58) + %r0_mask = arith.cmpi slt, %r0_index_16, %cst_3 : tensor<1x4xi32> loc(#loc59) + %tmp0 = arith.muli %x1_12, %cst_2 : tensor<64x1xi32> loc(#loc60) + %tmp0_17 = tt.broadcast %r0_index_16 : tensor<1x4xi32> 
-> tensor<64x4xi32> loc(#loc61) + %tmp0_18 = tt.broadcast %tmp0 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc61) + %tmp0_19 = arith.addi %tmp0_17, %tmp0_18 : tensor<64x4xi32> loc(#loc61) + %tmp0_20 = arith.muli %x0, %cst_1 : tensor<64x1xi32> loc(#loc62) + %tmp0_21 = tt.broadcast %tmp0_20 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc63) + %tmp0_22 = arith.addi %tmp0_19, %tmp0_21 : tensor<64x4xi32> loc(#loc63) + %tmp0_23 = arith.muli %x2_13, %cst_0 : tensor<64x1xi32> loc(#loc64) + %tmp0_24 = tt.broadcast %tmp0_23 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc65) + %tmp0_25 = arith.addi %tmp0_22, %tmp0_24 : tensor<64x4xi32> loc(#loc65) + %tmp0_26 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc66) + %tmp0_27 = tt.addptr %tmp0_26, %tmp0_25 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc66) + %tmp0_28 = tt.broadcast %r0_mask : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc67) + %tmp0_29 = tt.load %tmp0_27, %tmp0_28, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc67) + %tmp0_30 = arith.extf %tmp0_29 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc68) + %tmp1 = arith.muli %xindex_9, %cst_2 : tensor<64x1xi32> loc(#loc69) + %tmp1_31 = tt.broadcast %tmp1 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc70) + %tmp1_32 = arith.addi %tmp0_17, %tmp1_31 : tensor<64x4xi32> loc(#loc70) + %tmp1_33 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc71) + %tmp1_34 = tt.addptr %tmp1_33, %tmp1_32 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc71) + %tmp1_35 = tt.load %tmp1_34, %tmp0_28, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc72) + %tmp1_36 = arith.extf %tmp1_35 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc73) + %tmp2 = arith.mulf %tmp0_30, %tmp1_36 : tensor<64x4xf32> loc(#loc74) + %tmp5 = arith.addf %_tmp4_15, %tmp2 : tensor<64x4xf32> loc(#loc75) + %_tmp4_37 = arith.select %tmp0_28, %tmp5, %_tmp4_15 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc76) + scf.yield %_tmp4_37 : tensor<64x4xf32> loc(#loc33) + } 
loc(#loc57) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_15: f32 loc(callsite(#loc1 at #loc77)), %tmp4_16: f32 loc(callsite(#loc1 at #loc77))): + %tmp4_17 = arith.addf %tmp4_15, %tmp4_16 : f32 loc(#loc81) + tt.reduce.return %tmp4_17 : f32 loc(#loc79) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc79) + %tmp4_14 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc78) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc38) + %1 = tt.addptr %0, %xindex_9 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc38) + tt.store %1, %tmp4_14 : tensor<64x1x!tt.ptr> loc(#loc39) + tt.return loc(#loc40) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":33:40) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:36) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:27) +#loc11 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":34:31) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc25 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:8) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc46 = 
loc("x2"(#loc3)) +#loc47 = loc("x1"(#loc4)) +#loc48 = loc("xoffset"(#loc5)) +#loc49 = loc("xoffset"(#loc6)) +#loc50 = loc("xindex"(#loc7)) +#loc51 = loc("xindex"(#loc8)) +#loc52 = loc("xindex"(#loc9)) +#loc53 = loc("r0_base"(#loc10)) +#loc54 = loc("r0_base"(#loc11)) +#loc55 = loc("x0"(#loc12)) +#loc56 = loc("x1"(#loc13)) +#loc57 = loc("_tmp4"(#loc2)) +#loc58 = loc("r0_index"(#loc14)) +#loc59 = loc("r0_mask"(#loc15)) +#loc60 = loc("tmp0"(#loc16)) +#loc61 = loc("tmp0"(#loc17)) +#loc62 = loc("tmp0"(#loc18)) +#loc63 = loc("tmp0"(#loc19)) +#loc64 = loc("tmp0"(#loc20)) +#loc65 = loc("tmp0"(#loc21)) +#loc66 = loc("tmp0"(#loc22)) +#loc67 = loc("tmp0"(#loc23)) +#loc68 = loc("tmp0"(#loc24)) +#loc69 = loc("tmp1"(#loc25)) +#loc70 = loc("tmp1"(#loc26)) +#loc71 = loc("tmp1"(#loc27)) +#loc72 = loc("tmp1"(#loc28)) +#loc73 = loc("tmp1"(#loc29)) +#loc74 = loc("tmp2"(#loc30)) +#loc75 = loc("tmp5"(#loc31)) +#loc76 = loc("_tmp4"(#loc32)) +#loc78 = loc("tmp4"(#loc37)) +#loc79 = loc(callsite(#loc34 at #loc77)) +#loc81 = loc(callsite(#loc36 at #loc79)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/__grp__triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/__grp__triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..7ecb0ccdb9ff4ba090bc7e44519d4507aa881b09 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/__grp__triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_zeros_0.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.source", "triton_red_fused_zeros_0.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttir", 
"triton_red_fused_zeros_0.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttgir", "triton_red_fused_zeros_0.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.llir", "triton_red_fused_zeros_0.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ptx", "triton_red_fused_zeros_0.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.cubin", "triton_red_fused_zeros_0.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..6bd75961e8746de4eb9073a1f89eba64d49c72e6 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..ba053c60a955b83d36a1f6296a066d6581f27134 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"hash": "1b178f5dfe84e5bcbc00bc20de563fd4427bac43d930736557805e3c7c715574", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 4, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 16, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_zeros_0"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.llir new file mode 100644 index 0000000000000000000000000000000000000000..af1868ef364781cac51f8e8a270895f1bc768883 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.llir @@ -0,0 +1,158 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = 
"e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_zeros_0(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, ptr addrspace(1) readnone captures(none) %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #0 !dbg !4 { + %8 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %9 = shl i32 %8, 2, !dbg !8 + %10 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %11 = and i32 %10, 96, !dbg !9 + %12 = lshr exact i32 %11, 5, !dbg !9 + %13 = and i32 %10, 3, !dbg !9 + %14 = or disjoint i32 %12, %9, !dbg !10 + %15 = or disjoint i32 %9, %13, !dbg !10 + %16 = shl nuw nsw i32 %10, 2, !dbg !11 + %17 = and i32 %16, 124, !dbg !11 + %18 = sdiv i32 %14, 2048, !dbg !12 + %19 = mul i32 %18, 2048, !dbg !13 + %.decomposed = sub i32 %14, %19, !dbg !13 + %20 = srem i32 %18, 32, !dbg !14 + %21 = sdiv i32 %14, 65536, !dbg !15 + %22 = shl nsw i32 %20, 7, !dbg !16 + %23 = shl nsw i32 %.decomposed, 12, !dbg !17 + %24 = shl i32 %21, 23, !dbg !18 + %25 = or disjoint i32 %23, %17, !dbg !19 + %26 = add i32 %25, %24, !dbg !20 + %27 = add i32 %26, %22, !dbg !21 + %28 = sext i32 %27 to i64, !dbg !22 + %29 = getelementptr bfloat, ptr addrspace(1) %0, i64 %28, !dbg !22 + %30 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !23 + %31 = tail call { i32, i32 } asm sideeffect "mov.u32 $0, $2;\0A\09mov.u32 $1, $3;\0A\09@$6 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { $0, $1 }, [ $4 + 0 ], $5;", "=r,=r,r,r,l,l,b"(i32 0, i32 0, ptr addrspace(1) %29, i64 %30, i1 true) #4, !dbg !23 + %32 = extractvalue { i32, i32 } %31, 0, !dbg !23 + %33 = bitcast i32 %32 to <2 x bfloat>, !dbg !23 + %34 = extractvalue { i32, i32 } %31, 1, !dbg !23 + %35 = bitcast i32 %34 to <2 x bfloat>, !dbg !23 
+ %36 = shl i32 %14, 7, !dbg !24 + %37 = or disjoint i32 %36, %17, !dbg !25 + %38 = sext i32 %37 to i64, !dbg !26 + %39 = getelementptr bfloat, ptr addrspace(1) %1, i64 %38, !dbg !26 + %40 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !27 + %41 = tail call { i32, i32 } asm sideeffect "mov.u32 $0, $2;\0A\09mov.u32 $1, $3;\0A\09@$6 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { $0, $1 }, [ $4 + 0 ], $5;", "=r,=r,r,r,l,l,b"(i32 0, i32 0, ptr addrspace(1) %39, i64 %40, i1 true) #4, !dbg !27 + %42 = extractvalue { i32, i32 } %41, 0, !dbg !27 + %43 = bitcast i32 %42 to <2 x bfloat>, !dbg !27 + %44 = extractvalue { i32, i32 } %41, 1, !dbg !27 + %45 = bitcast i32 %44 to <2 x bfloat>, !dbg !27 + %46 = fpext <2 x bfloat> %33 to <2 x float>, !dbg !28 + %47 = fpext <2 x bfloat> %43 to <2 x float>, !dbg !29 + %48 = fmul <2 x float> %46, %47, !dbg !30 + %49 = fadd <2 x float> %48, zeroinitializer, !dbg !31 + %50 = fpext <2 x bfloat> %35 to <2 x float>, !dbg !28 + %51 = fpext <2 x bfloat> %45 to <2 x float>, !dbg !29 + %52 = fmul <2 x float> %50, %51, !dbg !30 + %53 = fadd <2 x float> %52, zeroinitializer, !dbg !31 + %shift = shufflevector <2 x float> %49, <2 x float> poison, <2 x i32> , !dbg !32 + %foldExtExtBinop = fadd <2 x float> %49, %shift, !dbg !32 + %foldExtExtBinop2 = fadd <2 x float> %53, %foldExtExtBinop, !dbg !32 + %shift4 = shufflevector <2 x float> %53, <2 x float> poison, <2 x i32> , !dbg !32 + %foldExtExtBinop5 = fadd <2 x float> %shift4, %foldExtExtBinop2, !dbg !32 + %54 = extractelement <2 x float> %foldExtExtBinop5, i64 0, !dbg !32 + %55 = bitcast float %54 to i32, !dbg !36 + %56 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %55, i32 16, i32 31), !dbg !36 + %57 = bitcast i32 %56 to float, !dbg !36 + %58 = fadd float %54, %57, !dbg !32 + %59 = bitcast float %58 to i32, !dbg !36 + %60 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %59, i32 8, i32 31), !dbg !36 
+ %61 = bitcast i32 %60 to float, !dbg !36 + %62 = fadd float %58, %61, !dbg !32 + %63 = bitcast float %62 to i32, !dbg !36 + %64 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %63, i32 4, i32 31), !dbg !36 + %65 = bitcast i32 %64 to float, !dbg !36 + %66 = fadd float %62, %65, !dbg !32 + %67 = bitcast float %66 to i32, !dbg !36 + %68 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %67, i32 2, i32 31), !dbg !36 + %69 = bitcast i32 %68 to float, !dbg !36 + %70 = fadd float %66, %69, !dbg !32 + %71 = bitcast float %70 to i32, !dbg !36 + %72 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %71, i32 1, i32 31), !dbg !36 + %73 = bitcast i32 %72 to float, !dbg !36 + %74 = fadd float %70, %73, !dbg !32 + %75 = lshr exact i32 %11, 3, !dbg !37 + %76 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %75, !dbg !37 + store float %74, ptr addrspace(3) %76, align 4, !dbg !37 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !37 + %77 = shl nuw nsw i32 %13, 2, !dbg !37 + %78 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %77, !dbg !37 + %79 = load i32, ptr addrspace(3) %78, align 4, !dbg !37 + %80 = sext i32 %15 to i64, !dbg !38 + %81 = getelementptr float, ptr addrspace(1) %2, i64 %80, !dbg !38 + %82 = and i32 %10, 124, !dbg !39 + %83 = icmp eq i32 %82, 0, !dbg !39 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %79, ptr addrspace(1) %81, i1 %83) #4, !dbg !39 + ret void, !dbg !40 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="128" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_zeros_0", linkageName: "triton_red_fused_zeros_0", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 23, column: 28, scope: !4) +!8 = !DILocation(line: 23, column: 33, scope: !4) +!9 = !DILocation(line: 24, column: 44, scope: !4) +!10 = !DILocation(line: 24, column: 23, scope: !4) +!11 = !DILocation(line: 26, column: 37, scope: !4) +!12 = !DILocation(line: 29, column: 21, scope: !4) +!13 = !DILocation(line: 28, column: 19, scope: !4) +!14 = !DILocation(line: 29, column: 29, scope: !4) +!15 = !DILocation(line: 30, column: 19, scope: !4) +!16 = !DILocation(line: 39, column: 45, scope: !4) +!17 = !DILocation(line: 39, column: 55, scope: !4) +!18 = !DILocation(line: 39, column: 68, scope: !4) +!19 = !DILocation(line: 39, column: 41, scope: !4) +!20 = !DILocation(line: 39, column: 50, scope: !4) +!21 = !DILocation(line: 39, column: 60, scope: 
!4) +!22 = !DILocation(line: 39, column: 34, scope: !4) +!23 = !DILocation(line: 39, column: 73, scope: !4) +!24 = !DILocation(line: 40, column: 45, scope: !4) +!25 = !DILocation(line: 40, column: 41, scope: !4) +!26 = !DILocation(line: 40, column: 34, scope: !4) +!27 = !DILocation(line: 40, column: 50, scope: !4) +!28 = !DILocation(line: 39, column: 127, scope: !4) +!29 = !DILocation(line: 40, column: 104, scope: !4) +!30 = !DILocation(line: 41, column: 22, scope: !4) +!31 = !DILocation(line: 43, column: 23, scope: !4) +!32 = !DILocation(line: 261, column: 15, scope: !33, inlinedAt: !35) +!33 = distinct !DILexicalBlockFile(scope: !4, file: !34, discriminator: 0) +!34 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!35 = !DILocation(line: 45, column: 25, scope: !4) +!36 = !DILocation(line: 291, column: 36, scope: !33, inlinedAt: !35) +!37 = !DILocation(line: 45, column: 28, scope: !4) +!38 = !DILocation(line: 49, column: 25, scope: !4) +!39 = !DILocation(line: 49, column: 36, scope: !4) +!40 = !DILocation(line: 49, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..6c7c8986621c6f63dbeab2bf5d73c3c0770965ca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ptx @@ -0,0 +1,412 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_zeros_0 // -- Begin function triton_red_fused_zeros_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_zeros_0 +.visible .entry triton_red_fused_zeros_0( + .param .u64 .ptr .global 
.align 1 triton_red_fused_zeros_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_2, + .param .u32 triton_red_fused_zeros_0_param_3, + .param .u32 triton_red_fused_zeros_0_param_4, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_6 +) +.reqntid 128 +{ + .reg .pred %p<4>; + .reg .b16 %rs<9>; + .reg .b32 %r<72>; + .reg .b64 %rd<11>; + .loc 1 18 0 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:18:0 + +// %bb.0: + ld.param.b64 %rd8, [triton_red_fused_zeros_0_param_0]; + ld.param.b64 %rd9, [triton_red_fused_zeros_0_param_1]; +$L__tmp0: + .loc 1 23 28 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:23:28 + mov.u32 %r10, %ctaid.x; + .loc 1 23 33 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:23:33 + shl.b32 %r11, %r10, 2; + ld.param.b64 %rd10, [triton_red_fused_zeros_0_param_2]; + .loc 1 24 44 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:44 + mov.u32 %r12, %tid.x; + and.b32 %r13, %r12, 96; + bfe.u32 %r14, %r12, 5, 2; + and.b32 %r15, %r12, 3; + .loc 1 24 23 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:24:23 + or.b32 %r16, %r14, %r11; + or.b32 %r17, %r11, %r15; + .loc 1 26 37 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:26:37 + shl.b32 %r18, %r12, 2; + and.b32 %r19, %r18, 124; + .loc 1 29 21 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:29:21 + bfe.s32 %r20, %r10, 29, 1; + shr.u32 %r21, %r20, 21; + add.s32 %r22, %r16, %r21; + shr.s32 %r23, %r22, 11; + .loc 1 28 19 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:28:19 + and.b32 %r24, %r22, 1046528; + sub.s32 %r25, %r16, %r24; + .loc 1 29 29 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:29:29 + shr.u32 %r26, %r23, 27; + add.s32 %r27, 
%r23, %r26; + and.b32 %r28, %r27, 33554400; + sub.s32 %r29, %r23, %r28; + .loc 1 30 19 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:30:19 + shr.u32 %r30, %r20, 16; + add.s32 %r31, %r16, %r30; + .loc 1 39 45 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:45 + shl.b32 %r32, %r29, 7; + .loc 1 39 55 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:55 + shl.b32 %r33, %r25, 12; + .loc 1 39 68 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:68 + shl.b32 %r34, %r31, 7; + and.b32 %r35, %r34, -8388608; + .loc 1 39 41 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:41 + or.b32 %r36, %r33, %r19; + .loc 1 39 50 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:50 + add.s32 %r37, %r36, %r35; + .loc 1 39 60 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:60 + add.s32 %r38, %r37, %r32; + .loc 1 39 34 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:34 + mad.wide.s32 %rd2, %r38, 2, %rd8; + .loc 1 39 73 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:73 + // begin inline asm + mov.u64 %rd3, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd3, 1.0; + // end inline asm + mov.b32 %r3, 0; + mov.pred %p1, -1; + // begin inline asm + mov.u32 %r1, %r3; + mov.u32 %r2, %r3; + @%p1 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r1, %r2 }, [ %rd2 + 0 ], %rd3; + // end inline asm + .loc 1 40 45 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:45 + shl.b32 %r39, %r16, 7; + .loc 1 40 41 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:41 + or.b32 %r40, %r39, %r19; + .loc 1 40 34 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:34 + mad.wide.s32 %rd5, %r40, 2, %rd9; + .loc 1 40 50 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:50 + // begin inline asm + mov.u64 %rd6, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd6, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r5, %r3; + mov.u32 %r6, %r3; + @%p1 
ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r5, %r6 }, [ %rd5 + 0 ], %rd6; + // end inline asm + .loc 1 39 127 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:127 + mov.b32 {%rs1, %rs2}, %r1; + cvt.f32.bf16 %r41, %rs1; + cvt.f32.bf16 %r42, %rs2; + .loc 1 40 104 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:104 + mov.b32 {%rs3, %rs4}, %r5; + cvt.f32.bf16 %r43, %rs3; + cvt.f32.bf16 %r44, %rs4; + .loc 1 43 23 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:43:23 + fma.rn.f32 %r45, %r42, %r44, 0f00000000; + fma.rn.f32 %r46, %r41, %r43, 0f00000000; + .loc 1 39 127 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:39:127 + mov.b32 {%rs5, %rs6}, %r2; + cvt.f32.bf16 %r47, %rs5; + cvt.f32.bf16 %r48, %rs6; + .loc 1 40 104 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:40:104 + mov.b32 {%rs7, %rs8}, %r6; + cvt.f32.bf16 %r49, %rs7; + cvt.f32.bf16 %r50, %rs8; + .loc 1 43 23 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:43:23 + fma.rn.f32 %r51, %r48, %r50, 0f00000000; + fma.rn.f32 %r52, %r47, %r49, 0f00000000; +$L__tmp1: + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r53, %r46, %r45; + add.f32 %r54, %r52, %r53; + add.f32 %r55, %r51, %r54; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r56, %r55, 16, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r57, %r55, %r56; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r58, %r57, 8, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r59, %r57, %r58; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r60, %r59, 4, 
31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r61, %r59, %r60; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r62, %r61, 2, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r63, %r61, %r62; + .loc 2 291 36 // standard.py:291:36 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + shfl.sync.bfly.b32 %r64, %r63, 1, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:25 ] + add.f32 %r65, %r63, %r64; +$L__tmp2: + .loc 1 45 28 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:45:28 + shr.u32 %r66, %r13, 3; + mov.b32 %r67, global_smem; + add.s32 %r68, %r67, %r66; + st.shared.b32 [%r68], %r65; + bar.sync 0; + shl.b32 %r69, %r15, 2; + add.s32 %r70, %r67, %r69; + ld.shared.b32 %r9, [%r70]; + .loc 1 49 25 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:25 + mad.wide.s32 %rd7, %r17, 4, %rd10; + .loc 1 49 36 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:36 + and.b32 %r71, %r12, 124; + setp.eq.b32 %p3, %r71, 0; + // begin inline asm + @%p3 st.global.b32 [ %rd7 + 0 ], { %r9 }; + // end inline asm + .loc 1 49 4 // cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py:49:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // 
DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 209 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. 
Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xca DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 122 +.b8 52 +.b8 105 +.b8 53 +.b8 112 +.b8 107 +.b8 117 +.b8 99 +.b8 117 +.b8 51 +.b8 108 +.b8 109 +.b8 120 +.b8 54 +.b8 113 +.b8 102 +.b8 100 +.b8 113 +.b8 97 +.b8 102 +.b8 106 +.b8 50 +.b8 122 +.b8 103 +.b8 103 +.b8 117 +.b8 110 +.b8 52 +.b8 121 +.b8 106 +.b8 114 +.b8 116 +.b8 103 +.b8 55 +.b8 97 +.b8 120 +.b8 50 +.b8 117 +.b8 102 +.b8 113 +.b8 116 +.b8 104 +.b8 107 +.b8 101 +.b8 52 +.b8 122 +.b8 97 +.b8 51 +.b8 102 +.b8 102 +.b8 55 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 122 +.b8 52 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1b DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa6:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbb:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // 
End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.source new file mode 100644 index 0000000000000000000000000000000000000000..a53070604fe63608dfb52a4d67332e5617575f48 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.source @@ -0,0 +1,222 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc46 = loc(unknown) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc53 = loc("in_ptr0"(#loc)) +#loc54 = loc("in_ptr1"(#loc)) +#loc55 = loc("out_ptr1"(#loc)) +#loc56 = loc("xnumel"(#loc)) +#loc57 = loc("r0_numel"(#loc)) +#loc97 = loc("input"(#loc44)) +#loc98 = loc("a"(#loc49)) +#loc99 = loc("b"(#loc49)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 524288 : i32 loc(#loc58) + %r0_numel_1 = arith.constant 128 : i32 loc(#loc59) + %xoffset = tt.get_program_id x : i32 loc(#loc60) + %xoffset_2 = arith.constant 4 : i32 loc(#loc61) + %xoffset_3 = arith.constant 4 : i32 loc(#loc61) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc61) + %xindex = tt.make_range {end = 4 : i32, 
start = 0 : i32} : tensor<4xi32> loc(#loc62) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> loc(#loc63) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<4x1xi32> loc(#loc64) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<4x1xi32> loc(#loc64) + %xmask = arith.constant true loc(#loc65) + %xmask_8 = arith.constant dense : tensor<4x128xi1> loc(#loc65) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc66) + %r0_base_9 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc67) + %x0 = arith.constant 2048 : i32 loc(#loc68) + %x0_10 = arith.constant 2048 : i32 loc(#loc68) + %x0_11 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc68) + %x0_12 = arith.remsi %xindex_7, %x0_11 : tensor<4x1xi32> loc(#loc68) + %x1 = arith.constant 2048 : i32 loc(#loc69) + %x1_13 = arith.constant 2048 : i32 loc(#loc69) + %x1_14 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc69) + %x1_15 = arith.divsi %xindex_7, %x1_14 : tensor<4x1xi32> loc(#loc69) + %x1_16 = arith.constant 32 : i32 loc(#loc70) + %x1_17 = arith.constant 32 : i32 loc(#loc70) + %x1_18 = arith.constant dense<32> : tensor<4x1xi32> loc(#loc70) + %x1_19 = arith.remsi %x1_15, %x1_18 : tensor<4x1xi32> loc(#loc70) + %x2 = arith.constant 65536 : i32 loc(#loc71) + %x2_20 = arith.constant 65536 : i32 loc(#loc71) + %x2_21 = arith.constant dense<65536> : tensor<4x1xi32> loc(#loc71) + %x2_22 = arith.divsi %xindex_7, %x2_21 : tensor<4x1xi32> loc(#loc71) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc72) + %_tmp4_23 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc72) + %c0_i32 = arith.constant 0 : i32 loc(#loc16) + %c128_i32 = arith.constant 128 : i32 loc(#loc16) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc16) + %1 = arith.bitcast %r0_numel_1 : i32 to i32 loc(#loc16) + %2 = arith.bitcast %c128_i32 : i32 to i32 loc(#loc16) + %3 = ub.poison : i32 loc(#loc16) + %_tmp4_24 = scf.for 
%r0_offset = %0 to %1 step %2 iter_args(%_tmp4_27 = %_tmp4_23) -> (tensor<4x128xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x128xi32> loc(#loc74) + %r0_index_28 = arith.addi %r0_index, %r0_base_9 : tensor<1x128xi32> loc(#loc74) + %r0_mask = arith.constant dense<128> : tensor<1x128xi32> loc(#loc75) + %r0_mask_29 = arith.cmpi slt, %r0_index_28, %r0_mask : tensor<1x128xi32> loc(#loc75) + %tmp0 = arith.constant 128 : i32 loc(#loc76) + %tmp0_30 = arith.constant 128 : i32 loc(#loc76) + %tmp0_31 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc76) + %tmp0_32 = arith.muli %tmp0_31, %x1_19 : tensor<4x1xi32> loc(#loc76) + %tmp0_33 = tt.broadcast %r0_index_28 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc77) + %tmp0_34 = tt.broadcast %tmp0_32 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc77) + %tmp0_35 = arith.addi %tmp0_33, %tmp0_34 : tensor<4x128xi32> loc(#loc77) + %tmp0_36 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_37 = arith.constant 4096 : i32 loc(#loc78) + %tmp0_38 = arith.constant dense<4096> : tensor<4x1xi32> loc(#loc78) + %tmp0_39 = arith.muli %tmp0_38, %x0_12 : tensor<4x1xi32> loc(#loc78) + %tmp0_40 = tt.broadcast %tmp0_39 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc79) + %tmp0_41 = arith.addi %tmp0_35, %tmp0_40 : tensor<4x128xi32> loc(#loc79) + %tmp0_42 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_43 = arith.constant 8388608 : i32 loc(#loc80) + %tmp0_44 = arith.constant dense<8388608> : tensor<4x1xi32> loc(#loc80) + %tmp0_45 = arith.muli %tmp0_44, %x2_22 : tensor<4x1xi32> loc(#loc80) + %tmp0_46 = tt.broadcast %tmp0_45 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc81) + %tmp0_47 = arith.addi %tmp0_41, %tmp0_46 : tensor<4x128xi32> loc(#loc81) + %tmp0_48 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc82) + %tmp0_49 = tt.addptr %tmp0_48, %tmp0_47 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc82) + %tmp0_50 = arith.constant 0.000000e+00 : f32 loc(#loc83) + %tmp0_51 = tt.broadcast %r0_mask_29 : 
tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc83) + %tmp0_52 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc83) + %tmp0_53 = arith.truncf %tmp0_52 : tensor<4x128xf32> to tensor<4x128xbf16> loc(#loc83) + %tmp0_54 = tt.load %tmp0_49, %tmp0_51, %tmp0_53 evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc83) + %tmp0_55 = arith.extf %tmp0_54 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc84) + %tmp1 = arith.constant 128 : i32 loc(#loc85) + %tmp1_56 = arith.constant 128 : i32 loc(#loc85) + %tmp1_57 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc85) + %tmp1_58 = arith.muli %tmp1_57, %xindex_7 : tensor<4x1xi32> loc(#loc85) + %tmp1_59 = tt.broadcast %r0_index_28 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc86) + %tmp1_60 = tt.broadcast %tmp1_58 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc86) + %tmp1_61 = arith.addi %tmp1_59, %tmp1_60 : tensor<4x128xi32> loc(#loc86) + %tmp1_62 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc87) + %tmp1_63 = tt.addptr %tmp1_62, %tmp1_61 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc87) + %tmp1_64 = arith.constant 0.000000e+00 : f32 loc(#loc88) + %tmp1_65 = tt.broadcast %r0_mask_29 : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc88) + %tmp1_66 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc88) + %tmp1_67 = arith.truncf %tmp1_66 : tensor<4x128xf32> to tensor<4x128xbf16> loc(#loc88) + %tmp1_68 = tt.load %tmp1_63, %tmp1_65, %tmp1_67 evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc88) + %tmp1_69 = arith.extf %tmp1_68 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc89) + %tmp2 = arith.mulf %tmp0_55, %tmp1_69 : tensor<4x128xf32> loc(#loc90) + %tmp5 = arith.addf %_tmp4_27, %tmp2 : tensor<4x128xf32> loc(#loc91) + %_tmp4_70 = tt.broadcast %r0_mask_29 : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc92) + %_tmp4_71 = arith.select %_tmp4_70, %tmp5, %_tmp4_27 : tensor<4x128xi1>, tensor<4x128xf32> loc(#loc92) + scf.yield %_tmp4_71 : tensor<4x128xf32> loc(#loc36) + 
} loc(#loc73) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S4_128S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_24) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc93) + %tmp4_25 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc94) + %tmp7 = arith.constant 0.000000e+00 : f32 loc(#loc95) + %tmp8 = arith.constant dense<0.000000e+00> : tensor<4x1xf32> loc(#loc96) + %tmp8_26 = arith.subf %tmp4_25, %tmp8 : tensor<4x1xf32> loc(#loc96) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc41) + %5 = tt.addptr %4, %xindex_7 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc41) + tt.store %5, %tmp8_26 : tensor<4x1x!tt.ptr> loc(#loc42) + tt.return loc(#loc43) + } loc(#loc) + tt.func private @"triton.language.standard.sum__fp32S4_128S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<4x128xf32> loc("input"(#loc44))) -> tensor<4xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc45) + tt.reduce.return %2 : f32 loc(#loc45) + }) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc45) + tt.return %0 : tensor<4xf32> loc(#loc47) + ^bb1: // no predecessors + %1 = ub.poison : tensor<4xf32> loc(#loc48) + tt.return %1 : tensor<4xf32> loc(#loc48) + } loc(#loc44) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc49)), %b: f32 loc("b"(#loc49))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc50) + tt.return %0 : f32 loc(#loc51) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc52) + tt.return %1 : f32 loc(#loc52) + } loc(#loc49) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":19:13) +#loc2 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":20:15) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":25:46) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:27) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":32:43) +#loc16 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":33:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":34:31) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc30 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:8) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":47:11) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":48:18) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc45 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc58 = loc("xnumel"(#loc1)) +#loc59 = loc("r0_numel"(#loc2)) +#loc60 = loc("xoffset"(#loc3)) +#loc61 = loc("xoffset"(#loc4)) +#loc62 = loc("xindex"(#loc5)) +#loc63 = loc("xindex"(#loc6)) +#loc64 = loc("xindex"(#loc7)) +#loc65 = loc("xmask"(#loc8)) +#loc66 = loc("r0_base"(#loc9)) +#loc67 = loc("r0_base"(#loc10)) +#loc68 = loc("x0"(#loc11)) +#loc69 = loc("x1"(#loc12)) +#loc70 = loc("x1"(#loc13)) +#loc71 = loc("x2"(#loc14)) +#loc72 = loc("_tmp4"(#loc15)) +#loc73 = loc("_tmp4"(#loc16)) +#loc74 = loc("r0_index"(#loc17)) +#loc75 = loc("r0_mask"(#loc18)) +#loc76 = loc("tmp0"(#loc19)) +#loc77 = loc("tmp0"(#loc20)) +#loc78 = loc("tmp0"(#loc21)) +#loc79 = loc("tmp0"(#loc22)) +#loc80 = loc("tmp0"(#loc23)) +#loc81 = loc("tmp0"(#loc24)) +#loc82 = loc("tmp0"(#loc25)) +#loc83 = loc("tmp0"(#loc26)) +#loc84 = loc("tmp0"(#loc27)) +#loc85 = loc("tmp1"(#loc28)) +#loc86 = loc("tmp1"(#loc29)) +#loc87 = loc("tmp1"(#loc30)) +#loc88 = loc("tmp1"(#loc31)) +#loc89 = loc("tmp1"(#loc32)) +#loc90 = loc("tmp2"(#loc33)) +#loc91 = loc("tmp5"(#loc34)) +#loc92 = loc("_tmp4"(#loc35)) +#loc93 = loc("tmp4"(#loc37)) +#loc94 = loc("tmp4"(#loc38)) +#loc95 = loc("tmp7"(#loc39)) +#loc96 = loc("tmp8"(#loc40)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttgir 
b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..19c7f42d41c6e9b430927ffffcdb4fc3e0e84866 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttgir @@ -0,0 +1,142 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc1 = loc(unknown) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc36 = loc("in_ptr0"(#loc)) +#loc37 = loc("in_ptr1"(#loc)) +#loc38 = loc("out_ptr1"(#loc)) +#loc39 = loc("xnumel"(#loc)) +#loc40 = loc("r0_numel"(#loc)) +#loc68 = loc("tmp4"(#loc30)) +#loc71 = loc(callsite(#loc1 at #loc68)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<1x128xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<8388608> : tensor<4x1xi32, 
#blocked> loc(#loc1) + %cst_3 = arith.constant dense<65536> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<32> : tensor<4x1xi32, #blocked> loc(#loc1) + %cst_5 = arith.constant dense<2048> : tensor<4x1xi32, #blocked> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x128xbf16, #blocked> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<4x128xf32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc41) + %xoffset_8 = arith.muli %xoffset, %c4_i32 : i32 loc(#loc42) + %xindex = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc43) + %xindex_9 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc43) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi32, #blocked> loc(#loc43) + %xindex_11 = tt.expand_dims %xindex_9 {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<4x1xi32, #blocked1> loc(#loc43) + %xindex_12 = tt.splat %xoffset_8 : i32 -> tensor<4x1xi32, #blocked> loc(#loc44) + %xindex_13 = tt.splat %xoffset_8 : i32 -> tensor<4x1xi32, #blocked1> loc(#loc44) + %xindex_14 = arith.addi %xindex_12, %xindex_10 : tensor<4x1xi32, #blocked> loc(#loc44) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<4x1xi32, #blocked1> loc(#loc44) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc45) + %r0_base_16 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc45) + %x0 = arith.remsi %xindex_14, %cst_5 : tensor<4x1xi32, #blocked> loc(#loc46) + %x1 = arith.divsi %xindex_14, %cst_5 : tensor<4x1xi32, #blocked> loc(#loc47) + %x1_17 = arith.remsi %x1, %cst_4 : 
tensor<4x1xi32, #blocked> loc(#loc48) + %x2 = arith.divsi %xindex_14, %cst_3 : tensor<4x1xi32, #blocked> loc(#loc49) + %r0_mask = arith.cmpi slt, %r0_base_16, %cst : tensor<1x128xi32, #blocked> loc(#loc50) + %tmp0 = arith.muli %x1_17, %cst_0 : tensor<4x1xi32, #blocked> loc(#loc51) + %tmp0_18 = tt.broadcast %r0_base_16 : tensor<1x128xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_19 = tt.broadcast %tmp0 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_20 = arith.addi %tmp0_18, %tmp0_19 : tensor<4x128xi32, #blocked> loc(#loc52) + %tmp0_21 = arith.muli %x0, %cst_1 : tensor<4x1xi32, #blocked> loc(#loc53) + %tmp0_22 = tt.broadcast %tmp0_21 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc54) + %tmp0_23 = arith.addi %tmp0_20, %tmp0_22 : tensor<4x128xi32, #blocked> loc(#loc54) + %tmp0_24 = arith.muli %x2, %cst_2 : tensor<4x1xi32, #blocked> loc(#loc55) + %tmp0_25 = tt.broadcast %tmp0_24 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc56) + %tmp0_26 = arith.addi %tmp0_23, %tmp0_25 : tensor<4x128xi32, #blocked> loc(#loc56) + %tmp0_27 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr, #blocked> loc(#loc57) + %tmp0_28 = tt.addptr %tmp0_27, %tmp0_26 : tensor<4x128x!tt.ptr, #blocked>, tensor<4x128xi32, #blocked> loc(#loc57) + %tmp0_29 = tt.broadcast %r0_mask : tensor<1x128xi1, #blocked> -> tensor<4x128xi1, #blocked> loc(#loc58) + %tmp0_30 = tt.load %tmp0_28, %tmp0_29, %cst_6 evictionPolicy = evict_first : tensor<4x128x!tt.ptr, #blocked> loc(#loc58) + %tmp0_31 = arith.extf %tmp0_30 : tensor<4x128xbf16, #blocked> to tensor<4x128xf32, #blocked> loc(#loc59) + %tmp1 = arith.muli %xindex_14, %cst_0 : tensor<4x1xi32, #blocked> loc(#loc60) + %tmp1_32 = tt.broadcast %tmp1 : tensor<4x1xi32, #blocked> -> tensor<4x128xi32, #blocked> loc(#loc61) + %tmp1_33 = arith.addi %tmp0_18, %tmp1_32 : tensor<4x128xi32, #blocked> loc(#loc61) + %tmp1_34 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr, #blocked> 
loc(#loc62) + %tmp1_35 = tt.addptr %tmp1_34, %tmp1_33 : tensor<4x128x!tt.ptr, #blocked>, tensor<4x128xi32, #blocked> loc(#loc62) + %tmp1_36 = tt.load %tmp1_35, %tmp0_29, %cst_6 evictionPolicy = evict_first : tensor<4x128x!tt.ptr, #blocked> loc(#loc63) + %tmp1_37 = arith.extf %tmp1_36 : tensor<4x128xbf16, #blocked> to tensor<4x128xf32, #blocked> loc(#loc64) + %tmp2 = arith.mulf %tmp0_31, %tmp1_37 : tensor<4x128xf32, #blocked> loc(#loc65) + %tmp5 = arith.addf %tmp2, %cst_7 : tensor<4x128xf32, #blocked> loc(#loc66) + %_tmp4 = arith.select %tmp0_29, %tmp5, %cst_7 : tensor<4x128xi1, #blocked>, tensor<4x128xf32, #blocked> loc(#loc67) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_40: f32 loc(callsite(#loc1 at #loc68)), %tmp4_41: f32 loc(callsite(#loc1 at #loc68))): + %tmp4_42 = arith.addf %tmp4_40, %tmp4_41 : f32 loc(#loc72) + tt.reduce.return %tmp4_42 : f32 loc(#loc70) + }) : (tensor<4x128xf32, #blocked>) -> tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc70) + %tmp4_38 = ttg.convert_layout %tmp4 : tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc69) + %tmp4_39 = tt.expand_dims %tmp4_38 {axis = 1 : i32} : tensor<4xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<4x1xf32, #blocked1> loc(#loc69) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr, #blocked1> loc(#loc33) + %1 = tt.addptr %0, %xindex_15 : tensor<4x1x!tt.ptr, #blocked1>, tensor<4x1xi32, #blocked1> loc(#loc33) + tt.store %1, %tmp4_39 : tensor<4x1x!tt.ptr, #blocked1> loc(#loc34) + tt.return loc(#loc35) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc4 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc18 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc34 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc41 = loc("xoffset"(#loc2)) +#loc42 = loc("xoffset"(#loc3)) +#loc43 = loc("xindex"(#loc4)) +#loc44 = loc("xindex"(#loc5)) +#loc45 = loc("r0_base"(#loc6)) +#loc46 = loc("x0"(#loc7)) +#loc47 = loc("x1"(#loc8)) +#loc48 = loc("x1"(#loc9)) +#loc49 = loc("x2"(#loc10)) +#loc50 = loc("r0_mask"(#loc11)) +#loc51 = loc("tmp0"(#loc12)) +#loc52 = loc("tmp0"(#loc13)) +#loc53 = loc("tmp0"(#loc14)) +#loc54 = loc("tmp0"(#loc15)) +#loc55 = loc("tmp0"(#loc16)) +#loc56 = loc("tmp0"(#loc17)) +#loc57 = loc("tmp0"(#loc18)) +#loc58 = loc("tmp0"(#loc19)) +#loc59 = loc("tmp0"(#loc20)) +#loc60 = loc("tmp1"(#loc21)) +#loc61 = loc("tmp1"(#loc22)) +#loc62 = loc("tmp1"(#loc23)) +#loc63 = loc("tmp1"(#loc24)) +#loc64 = loc("tmp1"(#loc25)) +#loc65 = loc("tmp2"(#loc26)) +#loc66 = loc("tmp5"(#loc27)) +#loc67 = loc("_tmp4"(#loc28)) +#loc69 = loc("tmp4"(#loc32)) +#loc70 = loc(callsite(#loc29 at #loc68)) +#loc72 = loc(callsite(#loc31 at #loc70)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..cbb5dff13532226b0e7c2650e4c12daea72a9da4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/DMLY6XP6QTS3ZPAAXQQN4VR72RBHXLCD3EYHGZKXQBPDY7DRKV2A/triton_red_fused_zeros_0.ttir @@ -0,0 +1,139 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":18:0) +#loc1 = loc(unknown) +#loc32 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:25) +#loc38 = loc("in_ptr0"(#loc)) +#loc39 = loc("in_ptr1"(#loc)) +#loc40 = loc("out_ptr1"(#loc)) +#loc41 = loc("xnumel"(#loc)) +#loc42 = loc("r0_numel"(#loc)) +#loc72 = loc("tmp4"(#loc32)) +#loc75 = loc(callsite(#loc1 at #loc72)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<4x128xbf16> loc(#loc1) + %cst_0 = arith.constant dense<8388608> : tensor<4x1xi32> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<4x1xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<4x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<4x128xf32> loc(#loc1) + %x2 = arith.constant dense<65536> : tensor<4x1xi32> loc(#loc43) + %x1 = arith.constant dense<32> : tensor<4x1xi32> loc(#loc44) + %cst_5 = arith.constant dense<2048> : tensor<4x1xi32> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc45) + %xoffset_6 = arith.muli %xoffset, %c4_i32 : i32 loc(#loc46) + %xindex = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc47) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> loc(#loc48) + %xindex_8 = tt.splat %xoffset_6 : i32 -> tensor<4x1xi32> loc(#loc49) + %xindex_9 = arith.addi %xindex_8, %xindex_7 : tensor<4x1xi32> loc(#loc49) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc50) + %r0_base_10 = 
tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc51) + %x0 = arith.remsi %xindex_9, %cst_5 : tensor<4x1xi32> loc(#loc52) + %x1_11 = arith.divsi %xindex_9, %cst_5 : tensor<4x1xi32> loc(#loc53) + %x1_12 = arith.remsi %x1_11, %x1 : tensor<4x1xi32> loc(#loc44) + %x2_13 = arith.divsi %xindex_9, %x2 : tensor<4x1xi32> loc(#loc43) + %r0_mask = arith.cmpi slt, %r0_base_10, %cst_3 : tensor<1x128xi32> loc(#loc54) + %tmp0 = arith.muli %x1_12, %cst_2 : tensor<4x1xi32> loc(#loc55) + %tmp0_14 = tt.broadcast %r0_base_10 : tensor<1x128xi32> -> tensor<4x128xi32> loc(#loc56) + %tmp0_15 = tt.broadcast %tmp0 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc56) + %tmp0_16 = arith.addi %tmp0_14, %tmp0_15 : tensor<4x128xi32> loc(#loc56) + %tmp0_17 = arith.muli %x0, %cst_1 : tensor<4x1xi32> loc(#loc57) + %tmp0_18 = tt.broadcast %tmp0_17 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc58) + %tmp0_19 = arith.addi %tmp0_16, %tmp0_18 : tensor<4x128xi32> loc(#loc58) + %tmp0_20 = arith.muli %x2_13, %cst_0 : tensor<4x1xi32> loc(#loc59) + %tmp0_21 = tt.broadcast %tmp0_20 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc60) + %tmp0_22 = arith.addi %tmp0_19, %tmp0_21 : tensor<4x128xi32> loc(#loc60) + %tmp0_23 = tt.splat %in_ptr0 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc61) + %tmp0_24 = tt.addptr %tmp0_23, %tmp0_22 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc61) + %tmp0_25 = tt.broadcast %r0_mask : tensor<1x128xi1> -> tensor<4x128xi1> loc(#loc62) + %tmp0_26 = tt.load %tmp0_24, %tmp0_25, %cst evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc62) + %tmp0_27 = arith.extf %tmp0_26 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc63) + %tmp1 = arith.muli %xindex_9, %cst_2 : tensor<4x1xi32> loc(#loc64) + %tmp1_28 = tt.broadcast %tmp1 : tensor<4x1xi32> -> tensor<4x128xi32> loc(#loc65) + %tmp1_29 = arith.addi %tmp0_14, %tmp1_28 : tensor<4x128xi32> loc(#loc65) + %tmp1_30 = tt.splat %in_ptr1 : !tt.ptr -> tensor<4x128x!tt.ptr> loc(#loc66) + %tmp1_31 = 
tt.addptr %tmp1_30, %tmp1_29 : tensor<4x128x!tt.ptr>, tensor<4x128xi32> loc(#loc66) + %tmp1_32 = tt.load %tmp1_31, %tmp0_25, %cst evictionPolicy = evict_first : tensor<4x128x!tt.ptr> loc(#loc67) + %tmp1_33 = arith.extf %tmp1_32 : tensor<4x128xbf16> to tensor<4x128xf32> loc(#loc68) + %tmp2 = arith.mulf %tmp0_27, %tmp1_33 : tensor<4x128xf32> loc(#loc69) + %tmp5 = arith.addf %tmp2, %cst_4 : tensor<4x128xf32> loc(#loc70) + %_tmp4 = arith.select %tmp0_25, %tmp5, %cst_4 : tensor<4x128xi1>, tensor<4x128xf32> loc(#loc71) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_35: f32 loc(callsite(#loc1 at #loc72)), %tmp4_36: f32 loc(callsite(#loc1 at #loc72))): + %tmp4_37 = arith.addf %tmp4_35, %tmp4_36 : f32 loc(#loc76) + tt.reduce.return %tmp4_37 : f32 loc(#loc74) + }) : (tensor<4x128xf32>) -> tensor<4xf32> loc(#loc74) + %tmp4_34 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc73) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc35) + %1 = tt.addptr %0, %xindex_9 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc35) + tt.store %1, %tmp4_34 : tensor<4x1x!tt.ptr> loc(#loc36) + tt.return loc(#loc37) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":30:19) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:29) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":23:33) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:36) +#loc7 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:44) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":24:23) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:27) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":26:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":28:19) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":29:21) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":35:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:45) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:41) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:55) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:50) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:68) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:60) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:34) +#loc21 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:73) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":39:127) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:45) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:41) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:50) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":40:104) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":41:22) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":43:23) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":44:40) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":45:28) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:25) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:36) +#loc37 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/z4/cz4i5pkucu3lmx6qfdqafj2zggun4yjrtg7ax2ufqthke4za3ff7.py":49:4) +#loc43 = loc("x2"(#loc2)) +#loc44 = loc("x1"(#loc3)) +#loc45 = loc("xoffset"(#loc4)) +#loc46 = loc("xoffset"(#loc5)) +#loc47 = loc("xindex"(#loc6)) +#loc48 = loc("xindex"(#loc7)) +#loc49 = loc("xindex"(#loc8)) +#loc50 = loc("r0_base"(#loc9)) +#loc51 = loc("r0_base"(#loc10)) +#loc52 = loc("x0"(#loc11)) +#loc53 = loc("x1"(#loc12)) +#loc54 = loc("r0_mask"(#loc13)) +#loc55 = loc("tmp0"(#loc14)) +#loc56 = loc("tmp0"(#loc15)) +#loc57 = loc("tmp0"(#loc16)) +#loc58 = loc("tmp0"(#loc17)) +#loc59 = loc("tmp0"(#loc18)) +#loc60 = loc("tmp0"(#loc19)) +#loc61 = loc("tmp0"(#loc20)) +#loc62 = loc("tmp0"(#loc21)) +#loc63 = loc("tmp0"(#loc22)) +#loc64 = loc("tmp1"(#loc23)) +#loc65 = loc("tmp1"(#loc24)) +#loc66 = loc("tmp1"(#loc25)) +#loc67 = loc("tmp1"(#loc26)) +#loc68 = loc("tmp1"(#loc27)) +#loc69 = loc("tmp2"(#loc28)) +#loc70 = loc("tmp5"(#loc29)) +#loc71 = loc("_tmp4"(#loc30)) +#loc73 = loc("tmp4"(#loc34)) +#loc74 = loc(callsite(#loc31 at #loc72)) +#loc76 = loc(callsite(#loc33 at #loc74)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json new file mode 100644 index 0000000000000000000000000000000000000000..4461860d9c1c7e163aa02862aa86adbf20b80fa7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json @@ -0,0 +1 @@ +{"child_paths": 
{"triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin new file mode 100644 index 0000000000000000000000000000000000000000..132394623f0c7fb705d29e825c34b9df97cbdba4 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json new file mode 100644 index 
0000000000000000000000000000000000000000..ffc570e253857448771e8107fc02cc3a37b0af51 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json @@ -0,0 +1 @@ +{"hash": "216454436a742975997ad2492b746bae3cb4daa170e2ff11d7e2252d07da3898", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 4, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 4096, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir 
b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir new file mode 100644 index 0000000000000000000000000000000000000000..5094f17e9fb033d6ea649d5e4d9d1855e53ec591 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir @@ -0,0 +1,1533 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@assertFunc_1 = internal constant [8 x i8] c"unknown\00" +@assertFile_1 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py\00" +@assertMessage_1 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp49 < 17\00" +@assertFunc_0 = internal constant [8 x i8] c"unknown\00" +@assertFile_0 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py\00" +@assertMessage_0 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp40 < 17\00" +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: noreturn +declare !dbg !5 void @__assertfail(ptr, ptr, i32, ptr, i64) local_unnamed_addr #0 + +define ptx_kernel void @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, i32 %7, i32 %8, ptr addrspace(1) readnone captures(none) %9, ptr addrspace(1) readnone 
captures(none) %10) local_unnamed_addr #1 !dbg !9 { + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !10 + %13 = shl i32 %12, 5, !dbg !11 + %14 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12 + %15 = and i32 %14, 124, !dbg !12 + %16 = lshr exact i32 %15, 2, !dbg !12 + %17 = and i32 %14, 31, !dbg !12 + %18 = or disjoint i32 %16, %13, !dbg !13 + %19 = icmp slt i32 %18, 128, !dbg !14 + %20 = and i32 %14, 3, !dbg !15 + %21 = shl i32 %18, 4, !dbg !16 + %22 = and i32 %14, 1, !dbg !17 + %23 = lshr i32 %14, 1, !dbg !17 + %.lobit = and i32 %23, 1, !dbg !17 + %24 = xor i32 %22, 1, !dbg !21 + %25 = xor i32 %.lobit, 1, !dbg !21 + %26 = trunc i32 %14 to i1, !dbg !22 + %27 = trunc i32 %23 to i1, !dbg !22 + %28 = insertelement <2 x i1> poison, i1 %27, i64 0, !dbg !23 + %29 = shufflevector <2 x i1> %28, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !23 + %30 = insertelement <4 x i32> poison, i32 %25, i64 0, !dbg !25 + %31 = shufflevector <4 x i32> %30, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !25 + %32 = insertelement <4 x i32> poison, i32 %.lobit, i64 0, !dbg !26 + %33 = shufflevector <4 x i32> %32, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !26 + %34 = insertelement <4 x i32> poison, i32 %24, i64 0, !dbg !27 + %35 = shufflevector <4 x i32> %34, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !27 + %36 = insertelement <4 x i32> poison, i32 %22, i64 0, !dbg !28 + %37 = shufflevector <4 x i32> %36, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !28 + %38 = insertelement <2 x i1> poison, i1 %26, i64 0, !dbg !22 + %39 = shufflevector <2 x i1> %38, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !22 + %40 = insertelement <2 x i32> poison, i32 %24, i64 0, !dbg !29 + %41 = shufflevector <2 x i32> %40, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !29 + %42 = insertelement <2 x i32> poison, i32 %22, i64 0, !dbg !30 + %43 = shufflevector <2 x i32> %42, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !30 + %44 = insertelement <2 x 
i32> %42, i32 %24, i64 1, !dbg !31 + %45 = insertelement <2 x i32> %40, i32 %22, i64 1, !dbg !32 + %46 = insertelement <4 x i1> poison, i1 %19, i64 0, !dbg !33 + %47 = shufflevector <4 x i1> %46, <4 x i1> poison, <4 x i32> zeroinitializer, !dbg !33 + %48 = shl nuw nsw i32 %15, 1, !dbg !34 + %49 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %48, !dbg !34 + %50 = shl nuw nsw i32 %17, 3, !dbg !34 + %51 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %50, !dbg !34 + %52 = shl nuw nsw i32 %20, 2, !dbg !15 + %53 = or disjoint i32 %52, 1, !dbg !15 + %54 = or disjoint i32 %52, 2, !dbg !15 + %55 = or disjoint i32 %52, 3, !dbg !15 + %56 = or disjoint i32 %21, %52, !dbg !35 + %57 = or disjoint i32 %21, %54, !dbg !35 + %58 = sext i32 %56 to i64, !dbg !36 + %59 = getelementptr i64, ptr addrspace(1) %0, i64 %58, !dbg !36 + %60 = sext i32 %57 to i64, !dbg !36 + %61 = getelementptr i64, ptr addrspace(1) %0, i64 %60, !dbg !36 + %62 = tail call { i64, i64 } asm sideeffect "mov.u64 $0, 0x0;\0A\09mov.u64 $1, 0x0;\0A\09@$3 ld.global.v2.b64 { $0, $1 }, [ $2 + 0 ];", "=l,=l,l,b"(ptr addrspace(1) %59, i1 %19) #6, !dbg !37 + %63 = extractvalue { i64, i64 } %62, 0, !dbg !37 + %64 = extractvalue { i64, i64 } %62, 1, !dbg !37 + %65 = tail call { i64, i64 } asm sideeffect "mov.u64 $0, 0x0;\0A\09mov.u64 $1, 0x0;\0A\09@$3 ld.global.v2.b64 { $0, $1 }, [ $2 + 0 ];", "=l,=l,l,b"(ptr addrspace(1) %61, i1 %19) #6, !dbg !37 + %66 = extractvalue { i64, i64 } %65, 0, !dbg !37 + %67 = extractvalue { i64, i64 } %65, 1, !dbg !37 + %68 = xor i32 %53, %52, !dbg !38 + %69 = xor i32 %54, %55, !dbg !38 + %70 = insertelement <4 x i64> poison, i64 %63, i64 0, !dbg !39 + %71 = insertelement <4 x i64> %70, i64 %64, i64 1, !dbg !39 + %72 = insertelement <4 x i64> %71, i64 %66, i64 2, !dbg !39 + %73 = insertelement <4 x i64> %72, i64 %67, i64 3, !dbg !39 + %74 = add <4 x i64> %73, splat (i64 -1), !dbg !39 + %75 = icmp ult <4 x i64> %74, splat (i64 16383), !dbg !39 + %76 = 
extractelement <4 x i1> %75, i64 0, !dbg !40 + %77 = zext i1 %76 to i32, !dbg !41 + %78 = extractelement <4 x i1> %75, i64 1, !dbg !40 + %79 = zext i1 %78 to i32, !dbg !41 + %80 = extractelement <4 x i1> %75, i64 2, !dbg !42 + %81 = zext i1 %80 to i32, !dbg !41 + %82 = extractelement <4 x i1> %75, i64 3, !dbg !42 + %83 = zext i1 %82 to i32, !dbg !41 + %84 = xor i1 %76, true, !dbg !40 + %85 = and i1 %78, %84, !dbg !40 + %.not = xor i1 %82, true, !dbg !42 + %86 = or i1 %80, %.not, !dbg !42 + %87 = xor i32 %77, %79, !dbg !43 + %88 = xor i32 %81, %83, !dbg !43 + %89 = select i1 %85, i32 %87, i32 0, !dbg !44 + %90 = select i1 %86, i32 %88, i32 0, !dbg !44 + %91 = xor i32 %89, %77, !dbg !45 + %92 = xor i32 %89, %79, !dbg !45 + %93 = xor i32 %90, %81, !dbg !45 + %94 = xor i32 %90, %83, !dbg !45 + %95 = select i1 %85, i32 %68, i32 0, !dbg !46 + %96 = select i1 %86, i32 %69, i32 0, !dbg !46 + %97 = xor i32 %95, %52, !dbg !47 + %98 = xor i32 %95, %53, !dbg !47 + %99 = xor i32 %96, %54, !dbg !47 + %100 = xor i32 %96, %55, !dbg !47 + %101 = insertelement <2 x i32> poison, i32 %91, i64 0, !dbg !22 + %102 = insertelement <2 x i32> %101, i32 %92, i64 1, !dbg !22 + %103 = insertelement <2 x i32> poison, i32 %93, i64 0, !dbg !22 + %104 = insertelement <2 x i32> %103, i32 %94, i64 1, !dbg !22 + %105 = icmp samesign uge <2 x i32> %102, %104, !dbg !22 + %106 = insertelement <2 x i32> %104, i32 %92, i64 1, !dbg !22 + %107 = insertelement <2 x i32> %102, i32 %94, i64 1, !dbg !22 + %108 = icmp ne <2 x i32> %106, %107, !dbg !22 + %109 = insertelement <2 x i32> poison, i32 %97, i64 0, !dbg !22 + %110 = insertelement <2 x i32> %109, i32 %98, i64 1, !dbg !22 + %111 = insertelement <2 x i32> poison, i32 %99, i64 0, !dbg !22 + %112 = insertelement <2 x i32> %111, i32 %100, i64 1, !dbg !22 + %113 = icmp ule <2 x i32> %110, %112, !dbg !22 + %114 = or <2 x i1> %113, %108, !dbg !22 + %115 = and <2 x i1> %105, %114, !dbg !22 + %116 = xor <2 x i1> %115, %39, !dbg !22 + %117 = xor <2 x i32> %107, 
%106, !dbg !43 + %118 = select <2 x i1> %116, <2 x i32> zeroinitializer, <2 x i32> %117, !dbg !44 + %119 = extractelement <2 x i32> %118, i64 0, !dbg !45 + %120 = xor i32 %119, %91, !dbg !45 + %121 = extractelement <2 x i32> %118, i64 1, !dbg !45 + %122 = xor i32 %121, %92, !dbg !45 + %123 = xor <2 x i32> %118, %104, !dbg !45 + %124 = xor i32 %99, %97, !dbg !38 + %125 = xor i32 %100, %98, !dbg !38 + %126 = extractelement <2 x i1> %116, i64 0, !dbg !46 + %127 = select i1 %126, i32 0, i32 %124, !dbg !46 + %128 = extractelement <2 x i1> %116, i64 1, !dbg !46 + %129 = select i1 %128, i32 0, i32 %125, !dbg !46 + %130 = xor i32 %127, %97, !dbg !47 + %131 = xor i32 %129, %98, !dbg !47 + %132 = xor i32 %127, %99, !dbg !47 + %133 = xor i32 %129, %100, !dbg !47 + %134 = icmp samesign uge i32 %120, %122, !dbg !22 + %135 = icmp ne i32 %120, %122, !dbg !22 + %136 = icmp samesign ule i32 %130, %131, !dbg !22 + %137 = or i1 %135, %136, !dbg !22 + %138 = and i1 %134, %137, !dbg !22 + %.not3 = xor i1 %138, %26, !dbg !22 + %139 = extractelement <2 x i32> %123, i64 0, !dbg !22 + %140 = extractelement <2 x i32> %123, i64 1, !dbg !22 + %141 = icmp samesign uge i32 %139, %140, !dbg !22 + %142 = icmp ne i32 %139, %140, !dbg !22 + %143 = icmp samesign ule i32 %132, %133, !dbg !22 + %144 = or i1 %142, %143, !dbg !22 + %145 = and i1 %141, %144, !dbg !22 + %.not4 = xor i1 %145, %26, !dbg !22 + %146 = xor i32 %120, %122, !dbg !43 + %147 = xor i32 %139, %140, !dbg !43 + %148 = select i1 %.not3, i32 0, i32 %146, !dbg !44 + %149 = select i1 %.not4, i32 0, i32 %147, !dbg !44 + %150 = xor i32 %148, %120, !dbg !45 + %151 = xor i32 %148, %122, !dbg !45 + %152 = xor i32 %149, %139, !dbg !45 + %153 = xor i32 %149, %140, !dbg !45 + %154 = xor i32 %130, %131, !dbg !38 + %155 = xor i32 %132, %133, !dbg !38 + %156 = select i1 %.not3, i32 0, i32 %154, !dbg !46 + %157 = select i1 %.not4, i32 0, i32 %155, !dbg !46 + %158 = xor i32 %156, %130, !dbg !47 + %159 = xor i32 %156, %131, !dbg !47 + %160 = xor i32 
%157, %132, !dbg !47 + %161 = xor i32 %157, %133, !dbg !47 + %162 = mul nuw nsw i32 %150, %24, !dbg !29 + %163 = mul nuw nsw i32 %151, %24, !dbg !29 + %164 = mul nuw nsw i32 %152, %24, !dbg !29 + %165 = mul nuw nsw i32 %153, %24, !dbg !29 + %166 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %162, i32 1, i32 31), !dbg !48 + %167 = add i32 %162, %166, !dbg !51 + %168 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %163, i32 1, i32 31), !dbg !48 + %169 = add i32 %163, %168, !dbg !51 + %170 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %164, i32 1, i32 31), !dbg !48 + %171 = add i32 %164, %170, !dbg !51 + %172 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %165, i32 1, i32 31), !dbg !48 + %173 = add i32 %165, %172, !dbg !51 + %174 = mul nuw nsw i32 %150, %22, !dbg !30 + %175 = mul nuw nsw i32 %151, %22, !dbg !30 + %176 = mul nuw nsw i32 %152, %22, !dbg !30 + %177 = mul nuw nsw i32 %153, %22, !dbg !30 + %178 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %174, i32 1, i32 31), !dbg !48 + %179 = add i32 %174, %178, !dbg !51 + %180 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %175, i32 1, i32 31), !dbg !48 + %181 = add i32 %175, %180, !dbg !51 + %182 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %176, i32 1, i32 31), !dbg !48 + %183 = add i32 %176, %182, !dbg !51 + %184 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %177, i32 1, i32 31), !dbg !48 + %185 = add i32 %177, %184, !dbg !51 + %186 = mul nuw nsw i32 %158, %24, !dbg !32 + %187 = mul nuw nsw i32 %159, %24, !dbg !32 + %188 = mul nuw nsw i32 %160, %24, !dbg !32 + %189 = mul nuw nsw i32 %161, %24, !dbg !32 + %190 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %186, i32 1, i32 31), !dbg !48 + %191 = add i32 %186, %190, !dbg !51 + %192 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %187, i32 1, i32 31), !dbg !48 + %193 = add i32 %187, %192, !dbg !51 + %194 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, 
i32 %188, i32 1, i32 31), !dbg !48 + %195 = add i32 %188, %194, !dbg !51 + %196 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %189, i32 1, i32 31), !dbg !48 + %197 = add i32 %189, %196, !dbg !51 + %198 = mul nuw nsw i32 %158, %22, !dbg !31 + %199 = mul nuw nsw i32 %159, %22, !dbg !31 + %200 = mul nuw nsw i32 %160, %22, !dbg !31 + %201 = mul nuw nsw i32 %161, %22, !dbg !31 + %202 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %198, i32 1, i32 31), !dbg !48 + %203 = add i32 %198, %202, !dbg !51 + %204 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %199, i32 1, i32 31), !dbg !48 + %205 = add i32 %199, %204, !dbg !51 + %206 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %200, i32 1, i32 31), !dbg !48 + %207 = add i32 %200, %206, !dbg !51 + %208 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %201, i32 1, i32 31), !dbg !48 + %209 = add i32 %201, %208, !dbg !51 + %210 = icmp sge i32 %167, %179, !dbg !22 + %211 = icmp ne i32 %167, %179, !dbg !22 + %212 = icmp sle i32 %191, %203, !dbg !22 + %213 = or i1 %211, %212, !dbg !22 + %214 = and i1 %210, %213, !dbg !22 + %.not5 = xor i1 %214, %27, !dbg !22 + %215 = icmp sge i32 %169, %181, !dbg !22 + %216 = icmp ne i32 %169, %181, !dbg !22 + %217 = icmp sle i32 %193, %205, !dbg !22 + %218 = or i1 %216, %217, !dbg !22 + %219 = and i1 %215, %218, !dbg !22 + %.not6 = xor i1 %219, %27, !dbg !22 + %220 = icmp sge i32 %171, %183, !dbg !22 + %221 = icmp ne i32 %171, %183, !dbg !22 + %222 = icmp sle i32 %195, %207, !dbg !22 + %223 = or i1 %221, %222, !dbg !22 + %224 = and i1 %220, %223, !dbg !22 + %.not7 = xor i1 %224, %27, !dbg !22 + %225 = icmp sge i32 %173, %185, !dbg !22 + %226 = icmp ne i32 %173, %185, !dbg !22 + %227 = icmp sle i32 %197, %209, !dbg !22 + %228 = or i1 %226, %227, !dbg !22 + %229 = and i1 %225, %228, !dbg !22 + %.not8 = xor i1 %229, %27, !dbg !22 + %230 = xor i32 %167, %179, !dbg !43 + %231 = xor i32 %169, %181, !dbg !43 + %232 = xor i32 %171, %183, !dbg !43 + %233 
= xor i32 %173, %185, !dbg !43 + %234 = select i1 %.not5, i32 0, i32 %230, !dbg !44 + %235 = select i1 %.not6, i32 0, i32 %231, !dbg !44 + %236 = select i1 %.not7, i32 0, i32 %232, !dbg !44 + %237 = select i1 %.not8, i32 0, i32 %233, !dbg !44 + %238 = xor i32 %234, %150, !dbg !45 + %239 = xor i32 %235, %151, !dbg !45 + %240 = xor i32 %236, %152, !dbg !45 + %241 = xor i32 %237, %153, !dbg !45 + %242 = xor i32 %191, %203, !dbg !38 + %243 = xor i32 %193, %205, !dbg !38 + %244 = xor i32 %195, %207, !dbg !38 + %245 = xor i32 %197, %209, !dbg !38 + %246 = select i1 %.not5, i32 0, i32 %242, !dbg !46 + %247 = select i1 %.not6, i32 0, i32 %243, !dbg !46 + %248 = select i1 %.not7, i32 0, i32 %244, !dbg !46 + %249 = select i1 %.not8, i32 0, i32 %245, !dbg !46 + %250 = xor i32 %246, %158, !dbg !47 + %251 = xor i32 %247, %159, !dbg !47 + %252 = xor i32 %248, %160, !dbg !47 + %253 = xor i32 %249, %161, !dbg !47 + %254 = icmp sge i32 %238, %240, !dbg !22 + %255 = icmp ne i32 %238, %240, !dbg !22 + %256 = icmp sle i32 %250, %252, !dbg !22 + %257 = or i1 %255, %256, !dbg !22 + %258 = and i1 %254, %257, !dbg !22 + %.not9 = xor i1 %258, %27, !dbg !22 + %259 = icmp sge i32 %239, %241, !dbg !22 + %260 = icmp ne i32 %239, %241, !dbg !22 + %261 = icmp sle i32 %251, %253, !dbg !22 + %262 = or i1 %260, %261, !dbg !22 + %263 = and i1 %259, %262, !dbg !22 + %.not10 = xor i1 %263, %27, !dbg !22 + %264 = xor i32 %238, %240, !dbg !43 + %265 = xor i32 %239, %241, !dbg !43 + %266 = select i1 %.not9, i32 0, i32 %264, !dbg !44 + %267 = select i1 %.not10, i32 0, i32 %265, !dbg !44 + %268 = xor i32 %266, %238, !dbg !45 + %269 = xor i32 %267, %239, !dbg !45 + %270 = xor i32 %266, %240, !dbg !45 + %271 = xor i32 %267, %241, !dbg !45 + %272 = xor i32 %250, %252, !dbg !38 + %273 = xor i32 %251, %253, !dbg !38 + %274 = select i1 %.not9, i32 0, i32 %272, !dbg !46 + %275 = select i1 %.not10, i32 0, i32 %273, !dbg !46 + %276 = xor i32 %274, %250, !dbg !47 + %277 = xor i32 %275, %251, !dbg !47 + %278 = xor 
i32 %274, %252, !dbg !47 + %279 = xor i32 %275, %253, !dbg !47 + %280 = icmp sge i32 %268, %269, !dbg !22 + %281 = icmp ne i32 %268, %269, !dbg !22 + %282 = icmp sle i32 %276, %277, !dbg !22 + %283 = or i1 %281, %282, !dbg !22 + %284 = and i1 %280, %283, !dbg !22 + %.not11 = xor i1 %284, %27, !dbg !22 + %285 = icmp sge i32 %270, %271, !dbg !22 + %286 = icmp ne i32 %270, %271, !dbg !22 + %287 = icmp sle i32 %278, %279, !dbg !22 + %288 = or i1 %286, %287, !dbg !22 + %289 = and i1 %285, %288, !dbg !22 + %.not12 = xor i1 %289, %27, !dbg !22 + %290 = xor i32 %268, %269, !dbg !43 + %291 = xor i32 %270, %271, !dbg !43 + %292 = select i1 %.not11, i32 0, i32 %290, !dbg !44 + %293 = select i1 %.not12, i32 0, i32 %291, !dbg !44 + %294 = xor i32 %276, %277, !dbg !38 + %295 = xor i32 %278, %279, !dbg !38 + %296 = select i1 %.not11, i32 0, i32 %294, !dbg !46 + %297 = select i1 %.not12, i32 0, i32 %295, !dbg !46 + %298 = xor i32 %296, %276, !dbg !47 + %299 = xor i32 %296, %277, !dbg !47 + %300 = xor i32 %297, %278, !dbg !47 + %301 = xor i32 %297, %279, !dbg !47 + %302 = mul nuw nsw i32 %298, %25, !dbg !32 + %303 = mul nuw nsw i32 %299, %25, !dbg !32 + %304 = mul nuw nsw i32 %300, %25, !dbg !32 + %305 = mul nuw nsw i32 %301, %25, !dbg !32 + %306 = mul nuw nsw i32 %298, %.lobit, !dbg !31 + %307 = mul nuw nsw i32 %299, %.lobit, !dbg !31 + %308 = mul nuw nsw i32 %300, %.lobit, !dbg !31 + %309 = mul nuw nsw i32 %301, %.lobit, !dbg !31 + %310 = insertelement <2 x i32> poison, i32 %292, i64 0, !dbg !45 + %311 = shufflevector <2 x i32> %310, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !45 + %312 = insertelement <2 x i32> poison, i32 %269, i64 0, !dbg !45 + %313 = insertelement <2 x i32> %312, i32 %268, i64 1, !dbg !45 + %314 = xor <2 x i32> %311, %313, !dbg !45 + %315 = insertelement <2 x i32> poison, i32 %293, i64 0, !dbg !45 + %316 = shufflevector <2 x i32> %315, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !45 + %317 = insertelement <2 x i32> poison, i32 %271, i64 0, !dbg 
!45 + %318 = insertelement <2 x i32> %317, i32 %270, i64 1, !dbg !45 + %319 = xor <2 x i32> %316, %318, !dbg !45 + %320 = extractelement <2 x i32> %314, i64 1, !dbg !30 + %321 = mul nuw nsw i32 %320, %25, !dbg !29 + %322 = extractelement <2 x i32> %314, i64 0, !dbg !30 + %323 = mul nuw nsw i32 %322, %25, !dbg !29 + %324 = extractelement <2 x i32> %319, i64 1, !dbg !30 + %325 = mul nuw nsw i32 %324, %25, !dbg !29 + %326 = extractelement <2 x i32> %319, i64 0, !dbg !30 + %327 = mul nuw nsw i32 %326, %25, !dbg !29 + %328 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %321, i32 2, i32 31), !dbg !48 + %329 = add i32 %321, %328, !dbg !51 + %330 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %323, i32 2, i32 31), !dbg !48 + %331 = add i32 %323, %330, !dbg !51 + %332 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %325, i32 2, i32 31), !dbg !48 + %333 = add i32 %325, %332, !dbg !51 + %334 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %327, i32 2, i32 31), !dbg !48 + %335 = add i32 %327, %334, !dbg !51 + %336 = mul nuw nsw i32 %320, %.lobit, !dbg !30 + %337 = mul nuw nsw i32 %322, %.lobit, !dbg !30 + %338 = mul nuw nsw i32 %324, %.lobit, !dbg !30 + %339 = mul nuw nsw i32 %326, %.lobit, !dbg !30 + %340 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %336, i32 2, i32 31), !dbg !48 + %341 = add i32 %336, %340, !dbg !51 + %342 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %337, i32 2, i32 31), !dbg !48 + %343 = add i32 %337, %342, !dbg !51 + %344 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %338, i32 2, i32 31), !dbg !48 + %345 = add i32 %338, %344, !dbg !51 + %346 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %339, i32 2, i32 31), !dbg !48 + %347 = add i32 %339, %346, !dbg !51 + %348 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %302, i32 2, i32 31), !dbg !48 + %349 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %303, i32 2, i32 31), !dbg !48 + %350 = insertelement 
<2 x i32> poison, i32 %303, i64 0, !dbg !51 + %351 = insertelement <2 x i32> %350, i32 %302, i64 1, !dbg !51 + %352 = insertelement <2 x i32> poison, i32 %349, i64 0, !dbg !51 + %353 = insertelement <2 x i32> %352, i32 %348, i64 1, !dbg !51 + %354 = add <2 x i32> %351, %353, !dbg !51 + %355 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %304, i32 2, i32 31), !dbg !48 + %356 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %305, i32 2, i32 31), !dbg !48 + %357 = insertelement <2 x i32> poison, i32 %305, i64 0, !dbg !51 + %358 = insertelement <2 x i32> %357, i32 %304, i64 1, !dbg !51 + %359 = insertelement <2 x i32> poison, i32 %356, i64 0, !dbg !51 + %360 = insertelement <2 x i32> %359, i32 %355, i64 1, !dbg !51 + %361 = add <2 x i32> %358, %360, !dbg !51 + %362 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %306, i32 2, i32 31), !dbg !48 + %363 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %307, i32 2, i32 31), !dbg !48 + %364 = insertelement <2 x i32> poison, i32 %307, i64 0, !dbg !51 + %365 = insertelement <2 x i32> %364, i32 %306, i64 1, !dbg !51 + %366 = insertelement <2 x i32> poison, i32 %363, i64 0, !dbg !51 + %367 = insertelement <2 x i32> %366, i32 %362, i64 1, !dbg !51 + %368 = add <2 x i32> %365, %367, !dbg !51 + %369 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %308, i32 2, i32 31), !dbg !48 + %370 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %309, i32 2, i32 31), !dbg !48 + %371 = insertelement <2 x i32> poison, i32 %309, i64 0, !dbg !51 + %372 = insertelement <2 x i32> %371, i32 %308, i64 1, !dbg !51 + %373 = insertelement <2 x i32> poison, i32 %370, i64 0, !dbg !51 + %374 = insertelement <2 x i32> %373, i32 %369, i64 1, !dbg !51 + %375 = add <2 x i32> %372, %374, !dbg !51 + %376 = insertelement <2 x i32> poison, i32 %331, i64 0, !dbg !40 + %377 = insertelement <2 x i32> %376, i32 %329, i64 1, !dbg !40 + %378 = insertelement <2 x i32> poison, i32 %343, i64 0, !dbg !40 + %379 = 
insertelement <2 x i32> %378, i32 %341, i64 1, !dbg !40 + %380 = icmp slt <2 x i32> %377, %379, !dbg !40 + %381 = icmp eq <2 x i32> %377, %379, !dbg !52 + %382 = icmp sgt <2 x i32> %354, %368, !dbg !53 + %383 = and <2 x i1> %381, %382, !dbg !54 + %384 = or <2 x i1> %380, %383, !dbg !55 + %385 = insertelement <2 x i32> poison, i32 %335, i64 0, !dbg !40 + %386 = insertelement <2 x i32> %385, i32 %333, i64 1, !dbg !40 + %387 = insertelement <2 x i32> poison, i32 %347, i64 0, !dbg !40 + %388 = insertelement <2 x i32> %387, i32 %345, i64 1, !dbg !40 + %389 = icmp slt <2 x i32> %386, %388, !dbg !40 + %390 = icmp eq <2 x i32> %386, %388, !dbg !52 + %391 = icmp sgt <2 x i32> %361, %375, !dbg !53 + %392 = and <2 x i1> %390, %391, !dbg !54 + %393 = or <2 x i1> %389, %392, !dbg !55 + %394 = xor <2 x i32> %377, %379, !dbg !43 + %395 = xor <2 x i32> %386, %388, !dbg !43 + %396 = select <2 x i1> %384, <2 x i32> %394, <2 x i32> zeroinitializer, !dbg !44 + %397 = select <2 x i1> %393, <2 x i32> %395, <2 x i32> zeroinitializer, !dbg !44 + %398 = xor <2 x i32> %396, %314, !dbg !45 + %399 = xor <2 x i32> %397, %319, !dbg !45 + %400 = xor <2 x i32> %354, %368, !dbg !38 + %401 = xor <2 x i32> %361, %375, !dbg !38 + %402 = select <2 x i1> %384, <2 x i32> %400, <2 x i32> zeroinitializer, !dbg !46 + %403 = select <2 x i1> %393, <2 x i32> %401, <2 x i32> zeroinitializer, !dbg !46 + %404 = insertelement <2 x i32> poison, i32 %299, i64 0, !dbg !47 + %405 = insertelement <2 x i32> %404, i32 %298, i64 1, !dbg !47 + %406 = xor <2 x i32> %402, %405, !dbg !47 + %407 = insertelement <2 x i32> poison, i32 %301, i64 0, !dbg !47 + %408 = insertelement <2 x i32> %407, i32 %300, i64 1, !dbg !47 + %409 = xor <2 x i32> %403, %408, !dbg !47 + %410 = mul nuw nsw <2 x i32> %398, %41, !dbg !29 + %411 = mul nuw nsw <2 x i32> %399, %41, !dbg !29 + %412 = extractelement <2 x i32> %410, i64 1, !dbg !48 + %413 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %412, i32 1, i32 31), !dbg !48 + %414 = 
extractelement <2 x i32> %410, i64 0, !dbg !48 + %415 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %414, i32 1, i32 31), !dbg !48 + %416 = insertelement <2 x i32> poison, i32 %415, i64 0, !dbg !51 + %417 = insertelement <2 x i32> %416, i32 %413, i64 1, !dbg !51 + %418 = add <2 x i32> %410, %417, !dbg !51 + %419 = extractelement <2 x i32> %411, i64 1, !dbg !48 + %420 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %419, i32 1, i32 31), !dbg !48 + %421 = extractelement <2 x i32> %411, i64 0, !dbg !48 + %422 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %421, i32 1, i32 31), !dbg !48 + %423 = insertelement <2 x i32> poison, i32 %422, i64 0, !dbg !51 + %424 = insertelement <2 x i32> %423, i32 %420, i64 1, !dbg !51 + %425 = add <2 x i32> %411, %424, !dbg !51 + %426 = mul nuw nsw <2 x i32> %398, %43, !dbg !30 + %427 = mul nuw nsw <2 x i32> %399, %43, !dbg !30 + %428 = extractelement <2 x i32> %426, i64 1, !dbg !48 + %429 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %428, i32 1, i32 31), !dbg !48 + %430 = extractelement <2 x i32> %426, i64 0, !dbg !48 + %431 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %430, i32 1, i32 31), !dbg !48 + %432 = insertelement <2 x i32> poison, i32 %431, i64 0, !dbg !51 + %433 = insertelement <2 x i32> %432, i32 %429, i64 1, !dbg !51 + %434 = add <2 x i32> %426, %433, !dbg !51 + %435 = extractelement <2 x i32> %427, i64 1, !dbg !48 + %436 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %435, i32 1, i32 31), !dbg !48 + %437 = extractelement <2 x i32> %427, i64 0, !dbg !48 + %438 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %437, i32 1, i32 31), !dbg !48 + %439 = insertelement <2 x i32> poison, i32 %438, i64 0, !dbg !51 + %440 = insertelement <2 x i32> %439, i32 %436, i64 1, !dbg !51 + %441 = add <2 x i32> %427, %440, !dbg !51 + %442 = mul nuw nsw <2 x i32> %406, %41, !dbg !32 + %443 = extractelement <2 x i32> %442, i64 1, !dbg !48 + %444 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %443, i32 1, i32 31), !dbg !48 + %445 = mul nuw nsw <2 x i32> %409, %41, !dbg !32 + %446 = extractelement <2 x i32> %442, i64 0, !dbg !48 + %447 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %446, i32 1, i32 31), !dbg !48 + %448 = insertelement <2 x i32> poison, i32 %447, i64 0, !dbg !51 + %449 = insertelement <2 x i32> %448, i32 %444, i64 1, !dbg !51 + %450 = add <2 x i32> %442, %449, !dbg !51 + %451 = mul nuw nsw <2 x i32> %409, %44, !dbg !31 + %452 = extractelement <2 x i32> %445, i64 1, !dbg !48 + %453 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %452, i32 1, i32 31), !dbg !48 + %454 = mul nuw nsw <2 x i32> %409, %45, !dbg !32 + %455 = extractelement <2 x i32> %445, i64 0, !dbg !48 + %456 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %455, i32 1, i32 31), !dbg !48 + %457 = insertelement <2 x i32> poison, i32 %456, i64 0, !dbg !51 + %458 = insertelement <2 x i32> %457, i32 %453, i64 1, !dbg !51 + %459 = add <2 x i32> %445, %458, !dbg !51 + %460 = mul nuw nsw <2 x i32> %406, %43, !dbg !31 + %461 = mul nuw nsw <2 x i32> %409, %43, !dbg !31 + %462 = extractelement <2 x i32> %460, i64 1, !dbg !48 + %463 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %462, i32 1, i32 31), !dbg !48 + %464 = extractelement <2 x i32> %460, i64 0, !dbg !48 + %465 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %464, i32 1, i32 31), !dbg !48 + %466 = insertelement <2 x i32> poison, i32 %465, i64 0, !dbg !51 + %467 = insertelement <2 x i32> %466, i32 %463, i64 1, !dbg !51 + %468 = add <2 x i32> %460, %467, !dbg !51 + %469 = extractelement <2 x i32> %461, i64 1, !dbg !48 + %470 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %469, i32 1, i32 31), !dbg !48 + %471 = insertelement <2 x i32> %458, i32 %470, i64 1, !dbg !51 + %472 = add <2 x i32> %454, %471, !dbg !51 + %473 = extractelement <2 x i32> %461, i64 0, !dbg !48 + %474 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 
%473, i32 1, i32 31), !dbg !48 + %475 = insertelement <2 x i32> %471, i32 %474, i64 0, !dbg !51 + %476 = add <2 x i32> %475, %461, !dbg !51 + %477 = insertelement <2 x i32> %475, i32 %453, i64 1, !dbg !51 + %478 = add <2 x i32> %477, %451, !dbg !51 + %479 = icmp slt <2 x i32> %418, %434, !dbg !40 + %480 = icmp slt <2 x i32> %425, %441, !dbg !40 + %481 = icmp eq <2 x i32> %418, %434, !dbg !52 + %482 = icmp eq <2 x i32> %425, %441, !dbg !52 + %483 = icmp sgt <2 x i32> %450, %468, !dbg !53 + %484 = icmp sgt <2 x i32> %459, %476, !dbg !53 + %485 = and <2 x i1> %481, %483, !dbg !54 + %486 = and <2 x i1> %482, %484, !dbg !54 + %487 = or <2 x i1> %479, %485, !dbg !55 + %488 = or <2 x i1> %480, %486, !dbg !55 + %489 = xor <2 x i32> %418, %434, !dbg !43 + %490 = xor <2 x i32> %425, %441, !dbg !43 + %491 = select <2 x i1> %487, <2 x i32> %489, <2 x i32> zeroinitializer, !dbg !44 + %492 = select <2 x i1> %488, <2 x i32> %490, <2 x i32> zeroinitializer, !dbg !44 + %493 = xor <2 x i32> %491, %398, !dbg !45 + %494 = xor <2 x i32> %492, %399, !dbg !45 + %495 = xor <2 x i32> %450, %468, !dbg !38 + %496 = xor <2 x i32> %478, %472, !dbg !38 + %497 = select <2 x i1> %487, <2 x i32> %495, <2 x i32> zeroinitializer, !dbg !46 + %498 = select <2 x i1> %488, <2 x i32> %496, <2 x i32> zeroinitializer, !dbg !46 + %499 = xor <2 x i32> %497, %406, !dbg !47 + %500 = xor <2 x i32> %498, %409, !dbg !47 + %501 = icmp slt <2 x i32> %493, %494, !dbg !40 + %502 = icmp eq <2 x i32> %493, %494, !dbg !52 + %503 = icmp sgt <2 x i32> %499, %500, !dbg !53 + %504 = and <2 x i1> %502, %503, !dbg !54 + %505 = or <2 x i1> %501, %504, !dbg !55 + %506 = xor <2 x i32> %493, %494, !dbg !43 + %507 = select <2 x i1> %505, <2 x i32> %506, <2 x i32> zeroinitializer, !dbg !44 + %508 = shufflevector <2 x i32> %507, <2 x i32> poison, <2 x i32> , !dbg !45 + %509 = shufflevector <2 x i32> %493, <2 x i32> %494, <2 x i32> , !dbg !45 + %510 = xor <2 x i32> %508, %509, !dbg !45 + %511 = shufflevector <2 x i32> %507, <2 x i32> 
poison, <2 x i32> , !dbg !45 + %512 = shufflevector <2 x i32> %494, <2 x i32> %493, <2 x i32> , !dbg !45 + %513 = xor <2 x i32> %511, %512, !dbg !45 + %514 = shufflevector <2 x i32> %507, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !45 + %515 = shufflevector <2 x i32> %494, <2 x i32> %493, <2 x i32> , !dbg !45 + %516 = xor <2 x i32> %514, %515, !dbg !45 + %517 = shufflevector <2 x i32> %494, <2 x i32> %493, <2 x i32> , !dbg !45 + %518 = xor <2 x i32> %507, %517, !dbg !45 + %519 = xor <2 x i32> %499, %500, !dbg !38 + %520 = select <2 x i1> %505, <2 x i32> %519, <2 x i32> zeroinitializer, !dbg !46 + %521 = xor <2 x i32> %520, %499, !dbg !47 + %522 = xor <2 x i32> %520, %500, !dbg !47 + %523 = icmp slt <2 x i32> %513, %516, !dbg !40 + %524 = icmp eq <2 x i32> %518, %510, !dbg !52 + %525 = shufflevector <2 x i32> %521, <2 x i32> %522, <2 x i32> , !dbg !53 + %526 = shufflevector <2 x i32> %522, <2 x i32> %521, <2 x i32> , !dbg !53 + %527 = icmp sgt <2 x i32> %525, %526, !dbg !53 + %528 = and <2 x i1> %524, %527, !dbg !54 + %529 = or <2 x i1> %523, %528, !dbg !55 + %530 = xor <2 x i32> %526, %525, !dbg !38 + %531 = select <2 x i1> %529, <2 x i32> %530, <2 x i32> zeroinitializer, !dbg !46 + %532 = shufflevector <2 x i32> %531, <2 x i32> poison, <4 x i32> , !dbg !46 + %533 = shufflevector <2 x i32> %521, <2 x i32> %522, <4 x i32> , !dbg !47 + %534 = xor <4 x i32> %532, %533, !dbg !47 + %535 = select <4 x i1> %47, <4 x i1> %75, <4 x i1> zeroinitializer, !dbg !33 + %536 = bitcast <4 x i1> %535 to i4, !dbg !56 + %537 = tail call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 %536), !dbg !56 + %538 = zext nneg i4 %537 to i64, !dbg !56 + %539 = zext nneg i4 %537 to i32, !dbg !58 + %540 = icmp eq <4 x i64> %73, splat (i64 16384), !dbg !59 + %541 = extractelement <4 x i1> %540, i64 0, !dbg !60 + %542 = zext i1 %541 to i32, !dbg !41 + %543 = extractelement <4 x i1> %540, i64 1, !dbg !60 + %544 = zext i1 %543 to i32, !dbg !41 + %545 = extractelement <4 x i1> %540, i64 2, !dbg !61 + %546 
= zext i1 %545 to i32, !dbg !41 + %547 = extractelement <4 x i1> %540, i64 3, !dbg !61 + %548 = zext i1 %547 to i32, !dbg !41 + %549 = xor i1 %541, true, !dbg !60 + %550 = and i1 %543, %549, !dbg !60 + %.not13 = xor i1 %547, true, !dbg !61 + %551 = or i1 %545, %.not13, !dbg !61 + %552 = xor i32 %542, %544, !dbg !62 + %553 = xor i32 %546, %548, !dbg !62 + %554 = select i1 %550, i32 %552, i32 0, !dbg !63 + %555 = select i1 %551, i32 %553, i32 0, !dbg !63 + %556 = xor i32 %554, %542, !dbg !64 + %557 = xor i32 %554, %544, !dbg !64 + %558 = xor i32 %555, %546, !dbg !64 + %559 = xor i32 %555, %548, !dbg !64 + %560 = select i1 %550, i32 %68, i32 0, !dbg !65 + %561 = select i1 %551, i32 %69, i32 0, !dbg !65 + %562 = xor i32 %560, %52, !dbg !66 + %563 = xor i32 %560, %53, !dbg !66 + %564 = xor i32 %561, %54, !dbg !66 + %565 = xor i32 %561, %55, !dbg !66 + %566 = insertelement <2 x i32> poison, i32 %556, i64 0, !dbg !23 + %567 = insertelement <2 x i32> %566, i32 %557, i64 1, !dbg !23 + %568 = insertelement <2 x i32> poison, i32 %558, i64 0, !dbg !23 + %569 = insertelement <2 x i32> %568, i32 %559, i64 1, !dbg !23 + %570 = icmp samesign uge <2 x i32> %567, %569, !dbg !23 + %571 = insertelement <2 x i32> %569, i32 %557, i64 1, !dbg !23 + %572 = insertelement <2 x i32> %567, i32 %559, i64 1, !dbg !23 + %573 = icmp ne <2 x i32> %571, %572, !dbg !23 + %574 = insertelement <2 x i32> poison, i32 %562, i64 0, !dbg !23 + %575 = insertelement <2 x i32> %574, i32 %563, i64 1, !dbg !23 + %576 = insertelement <2 x i32> poison, i32 %564, i64 0, !dbg !23 + %577 = insertelement <2 x i32> %576, i32 %565, i64 1, !dbg !23 + %578 = icmp ule <2 x i32> %575, %577, !dbg !23 + %579 = or <2 x i1> %578, %573, !dbg !23 + %580 = and <2 x i1> %570, %579, !dbg !23 + %581 = xor <2 x i1> %580, %39, !dbg !23 + %582 = xor <2 x i32> %572, %571, !dbg !62 + %583 = select <2 x i1> %581, <2 x i32> zeroinitializer, <2 x i32> %582, !dbg !63 + %584 = extractelement <2 x i32> %583, i64 0, !dbg !64 + %585 = xor i32 
%584, %556, !dbg !64 + %586 = extractelement <2 x i32> %583, i64 1, !dbg !64 + %587 = xor i32 %586, %557, !dbg !64 + %588 = xor <2 x i32> %583, %569, !dbg !64 + %589 = xor i32 %564, %562, !dbg !67 + %590 = xor i32 %565, %563, !dbg !67 + %591 = extractelement <2 x i1> %581, i64 0, !dbg !65 + %592 = select i1 %591, i32 0, i32 %589, !dbg !65 + %593 = extractelement <2 x i1> %581, i64 1, !dbg !65 + %594 = select i1 %593, i32 0, i32 %590, !dbg !65 + %595 = xor i32 %592, %562, !dbg !66 + %596 = xor i32 %594, %563, !dbg !66 + %597 = xor i32 %592, %564, !dbg !66 + %598 = xor i32 %594, %565, !dbg !66 + %599 = icmp samesign uge i32 %585, %587, !dbg !23 + %600 = icmp ne i32 %585, %587, !dbg !23 + %601 = icmp samesign ule i32 %595, %596, !dbg !23 + %602 = or i1 %600, %601, !dbg !23 + %603 = and i1 %599, %602, !dbg !23 + %.not16 = xor i1 %603, %26, !dbg !23 + %604 = extractelement <2 x i32> %588, i64 0, !dbg !23 + %605 = extractelement <2 x i32> %588, i64 1, !dbg !23 + %606 = icmp samesign uge i32 %604, %605, !dbg !23 + %607 = icmp ne i32 %604, %605, !dbg !23 + %608 = icmp samesign ule i32 %597, %598, !dbg !23 + %609 = or i1 %607, %608, !dbg !23 + %610 = and i1 %606, %609, !dbg !23 + %.not17 = xor i1 %610, %26, !dbg !23 + %611 = xor i32 %585, %587, !dbg !62 + %612 = xor i32 %604, %605, !dbg !62 + %613 = select i1 %.not16, i32 0, i32 %611, !dbg !63 + %614 = select i1 %.not17, i32 0, i32 %612, !dbg !63 + %615 = xor i32 %613, %585, !dbg !64 + %616 = xor i32 %613, %587, !dbg !64 + %617 = xor i32 %614, %604, !dbg !64 + %618 = xor i32 %614, %605, !dbg !64 + %619 = xor i32 %595, %596, !dbg !67 + %620 = xor i32 %597, %598, !dbg !67 + %621 = select i1 %.not16, i32 0, i32 %619, !dbg !65 + %622 = select i1 %.not17, i32 0, i32 %620, !dbg !65 + %623 = mul nuw nsw i32 %615, %24, !dbg !27 + %624 = mul nuw nsw i32 %616, %24, !dbg !27 + %625 = mul nuw nsw i32 %617, %24, !dbg !27 + %626 = mul nuw nsw i32 %618, %24, !dbg !27 + %627 = mul nuw nsw i32 %615, %22, !dbg !28 + %628 = mul nuw nsw i32 
%616, %22, !dbg !28 + %629 = mul nuw nsw i32 %617, %22, !dbg !28 + %630 = mul nuw nsw i32 %618, %22, !dbg !28 + %631 = insertelement <4 x i32> poison, i32 %622, i64 0, !dbg !66 + %632 = insertelement <4 x i32> %631, i32 %621, i64 1, !dbg !66 + %633 = shufflevector <4 x i32> %632, <4 x i32> poison, <4 x i32> , !dbg !66 + %634 = insertelement <4 x i32> poison, i32 %595, i64 0, !dbg !66 + %635 = insertelement <4 x i32> %634, i32 %596, i64 1, !dbg !66 + %636 = insertelement <4 x i32> %635, i32 %597, i64 2, !dbg !66 + %637 = insertelement <4 x i32> %636, i32 %598, i64 3, !dbg !66 + %638 = xor <4 x i32> %633, %637, !dbg !66 + %639 = extractelement <4 x i32> %638, i64 0, !dbg !26 + %640 = mul nuw nsw i32 %639, %24, !dbg !25 + %641 = extractelement <4 x i32> %638, i64 1, !dbg !26 + %642 = mul nuw nsw i32 %641, %24, !dbg !25 + %643 = extractelement <4 x i32> %638, i64 2, !dbg !26 + %644 = mul nuw nsw i32 %643, %24, !dbg !25 + %645 = extractelement <4 x i32> %638, i64 3, !dbg !26 + %646 = mul nuw nsw i32 %645, %24, !dbg !25 + %647 = mul nuw nsw i32 %639, %22, !dbg !26 + %648 = mul nuw nsw i32 %641, %22, !dbg !26 + %649 = mul nuw nsw i32 %643, %22, !dbg !26 + %650 = mul nuw nsw i32 %645, %22, !dbg !26 + %651 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %623, i32 1, i32 31), !dbg !68 + %652 = add i32 %651, %623, !dbg !69 + %653 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %624, i32 1, i32 31), !dbg !68 + %654 = add i32 %653, %624, !dbg !69 + %655 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %625, i32 1, i32 31), !dbg !68 + %656 = add i32 %655, %625, !dbg !69 + %657 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %626, i32 1, i32 31), !dbg !68 + %658 = add i32 %657, %626, !dbg !69 + %659 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %627, i32 1, i32 31), !dbg !68 + %660 = add i32 %659, %627, !dbg !69 + %661 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %628, i32 1, i32 31), !dbg !68 + %662 = add i32 %661, 
%628, !dbg !69 + %663 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %629, i32 1, i32 31), !dbg !68 + %664 = add i32 %663, %629, !dbg !69 + %665 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %630, i32 1, i32 31), !dbg !68 + %666 = add i32 %665, %630, !dbg !69 + %667 = icmp sge i32 %654, %662, !dbg !23 + %668 = icmp ne i32 %654, %662, !dbg !23 + %669 = insertelement <2 x i32> poison, i32 %652, i64 0, !dbg !23 + %670 = insertelement <2 x i32> %669, i32 %656, i64 1, !dbg !23 + %671 = insertelement <2 x i32> poison, i32 %660, i64 0, !dbg !23 + %672 = insertelement <2 x i32> %671, i32 %664, i64 1, !dbg !23 + %673 = icmp sge <2 x i32> %670, %672, !dbg !23 + %674 = icmp ne <2 x i32> %670, %672, !dbg !23 + %675 = icmp sge i32 %658, %666, !dbg !23 + %676 = icmp ne i32 %658, %666, !dbg !23 + %677 = xor i32 %660, %652, !dbg !62 + %678 = xor i32 %662, %654, !dbg !62 + %679 = xor i32 %664, %656, !dbg !62 + %680 = xor i32 %666, %658, !dbg !62 + %681 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %640, i32 1, i32 31), !dbg !68 + %682 = add i32 %681, %640, !dbg !69 + %683 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %642, i32 1, i32 31), !dbg !68 + %684 = add i32 %683, %642, !dbg !69 + %685 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %644, i32 1, i32 31), !dbg !68 + %686 = add i32 %685, %644, !dbg !69 + %687 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %646, i32 1, i32 31), !dbg !68 + %688 = add i32 %687, %646, !dbg !69 + %689 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %647, i32 1, i32 31), !dbg !68 + %690 = add i32 %689, %647, !dbg !69 + %691 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %648, i32 1, i32 31), !dbg !68 + %692 = add i32 %691, %648, !dbg !69 + %693 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %649, i32 1, i32 31), !dbg !68 + %694 = add i32 %693, %649, !dbg !69 + %695 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %650, i32 1, i32 31), !dbg !68 
+ %696 = add i32 %695, %650, !dbg !69 + %697 = insertelement <2 x i32> poison, i32 %682, i64 0, !dbg !23 + %698 = insertelement <2 x i32> %697, i32 %686, i64 1, !dbg !23 + %699 = insertelement <2 x i32> poison, i32 %690, i64 0, !dbg !23 + %700 = insertelement <2 x i32> %699, i32 %694, i64 1, !dbg !23 + %701 = icmp sle <2 x i32> %698, %700, !dbg !23 + %702 = or <2 x i1> %674, %701, !dbg !23 + %703 = and <2 x i1> %673, %702, !dbg !23 + %704 = xor <2 x i1> %703, %29, !dbg !23 + %705 = insertelement <2 x i32> poison, i32 %677, i64 0, !dbg !63 + %706 = insertelement <2 x i32> %705, i32 %679, i64 1, !dbg !63 + %707 = select <2 x i1> %704, <2 x i32> zeroinitializer, <2 x i32> %706, !dbg !63 + %708 = insertelement <2 x i32> poison, i32 %615, i64 0, !dbg !64 + %709 = insertelement <2 x i32> %708, i32 %617, i64 1, !dbg !64 + %710 = xor <2 x i32> %707, %709, !dbg !64 + %711 = insertelement <2 x i32> poison, i32 %688, i64 0, !dbg !23 + %712 = insertelement <2 x i32> %711, i32 %684, i64 1, !dbg !23 + %713 = insertelement <2 x i32> poison, i32 %696, i64 0, !dbg !23 + %714 = insertelement <2 x i32> %713, i32 %692, i64 1, !dbg !23 + %715 = icmp sle <2 x i32> %712, %714, !dbg !23 + %716 = insertelement <2 x i1> poison, i1 %676, i64 0, !dbg !23 + %717 = insertelement <2 x i1> %716, i1 %668, i64 1, !dbg !23 + %718 = or <2 x i1> %717, %715, !dbg !23 + %719 = insertelement <2 x i1> poison, i1 %675, i64 0, !dbg !23 + %720 = insertelement <2 x i1> %719, i1 %667, i64 1, !dbg !23 + %721 = and <2 x i1> %720, %718, !dbg !23 + %722 = xor <2 x i1> %721, %29, !dbg !23 + %723 = insertelement <2 x i32> poison, i32 %680, i64 0, !dbg !63 + %724 = insertelement <2 x i32> %723, i32 %678, i64 1, !dbg !63 + %725 = select <2 x i1> %722, <2 x i32> zeroinitializer, <2 x i32> %724, !dbg !63 + %726 = insertelement <2 x i32> poison, i32 %618, i64 0, !dbg !64 + %727 = insertelement <2 x i32> %726, i32 %616, i64 1, !dbg !64 + %728 = xor <2 x i32> %725, %727, !dbg !64 + %729 = xor i32 %690, %682, !dbg !67 + 
%730 = xor i32 %692, %684, !dbg !67 + %731 = xor i32 %694, %686, !dbg !67 + %732 = xor i32 %696, %688, !dbg !67 + %733 = shufflevector <2 x i1> %704, <2 x i1> %722, <4 x i32> , !dbg !65 + %734 = insertelement <4 x i32> poison, i32 %729, i64 0, !dbg !65 + %735 = insertelement <4 x i32> %734, i32 %730, i64 1, !dbg !65 + %736 = insertelement <4 x i32> %735, i32 %731, i64 2, !dbg !65 + %737 = insertelement <4 x i32> %736, i32 %732, i64 3, !dbg !65 + %738 = select <4 x i1> %733, <4 x i32> zeroinitializer, <4 x i32> %737, !dbg !65 + %739 = xor <4 x i32> %738, %638, !dbg !66 + %740 = extractelement <2 x i32> %710, i64 0, !dbg !23 + %741 = extractelement <2 x i32> %710, i64 1, !dbg !23 + %742 = icmp sge i32 %740, %741, !dbg !23 + %743 = icmp ne i32 %740, %741, !dbg !23 + %shift = shufflevector <4 x i32> %739, <4 x i32> poison, <4 x i32> , !dbg !23 + %744 = icmp sle <4 x i32> %739, %shift, !dbg !23 + %745 = extractelement <4 x i1> %744, i64 0, !dbg !23 + %746 = or i1 %743, %745, !dbg !23 + %747 = and i1 %742, %746, !dbg !23 + %748 = extractelement <2 x i32> %728, i64 0, !dbg !23 + %749 = extractelement <2 x i32> %728, i64 1, !dbg !23 + %750 = icmp sge i32 %749, %748, !dbg !23 + %751 = icmp ne i32 %749, %748, !dbg !23 + %shift53 = shufflevector <4 x i32> %739, <4 x i32> poison, <4 x i32> , !dbg !23 + %752 = icmp sle <4 x i32> %739, %shift53, !dbg !23 + %753 = extractelement <4 x i1> %752, i64 1, !dbg !23 + %754 = or i1 %751, %753, !dbg !23 + %755 = and i1 %750, %754, !dbg !23 + %756 = insertelement <2 x i1> poison, i1 %755, i64 0, !dbg !23 + %757 = insertelement <2 x i1> %756, i1 %747, i64 1, !dbg !23 + %758 = xor <2 x i1> %757, %29, !dbg !23 + %759 = xor i32 %741, %740, !dbg !62 + %760 = xor i32 %748, %749, !dbg !62 + %761 = extractelement <2 x i1> %758, i64 1, !dbg !63 + %762 = select i1 %761, i32 0, i32 %759, !dbg !63 + %763 = extractelement <2 x i1> %758, i64 0, !dbg !63 + %764 = select i1 %763, i32 0, i32 %760, !dbg !63 + %765 = xor i32 %762, %740, !dbg !64 + %766 = xor 
i32 %764, %749, !dbg !64 + %767 = xor i32 %762, %741, !dbg !64 + %768 = xor i32 %764, %748, !dbg !64 + %769 = shufflevector <4 x i32> %739, <4 x i32> poison, <2 x i32> , !dbg !67 + %770 = shufflevector <4 x i32> %739, <4 x i32> poison, <2 x i32> , !dbg !67 + %771 = xor <2 x i32> %769, %770, !dbg !67 + %772 = select <2 x i1> %758, <2 x i32> zeroinitializer, <2 x i32> %771, !dbg !65 + %773 = shufflevector <2 x i32> %772, <2 x i32> poison, <4 x i32> , !dbg !65 + %774 = shufflevector <2 x i32> %772, <2 x i32> poison, <2 x i32> , !dbg !66 + %775 = shufflevector <4 x i32> %739, <4 x i32> poison, <2 x i32> , !dbg !66 + %776 = xor <2 x i32> %774, %775, !dbg !66 + %777 = shufflevector <4 x i32> %739, <4 x i32> poison, <2 x i32> , !dbg !66 + %778 = xor <2 x i32> %772, %777, !dbg !66 + %779 = icmp sge i32 %765, %766, !dbg !23 + %780 = icmp ne i32 %765, %766, !dbg !23 + %781 = extractelement <2 x i32> %778, i64 1, !dbg !23 + %782 = extractelement <2 x i32> %776, i64 1, !dbg !23 + %783 = icmp sle i32 %781, %782, !dbg !23 + %784 = or i1 %780, %783, !dbg !23 + %785 = and i1 %779, %784, !dbg !23 + %.not24 = xor i1 %785, %27, !dbg !23 + %786 = icmp sge i32 %767, %768, !dbg !23 + %787 = icmp ne i32 %767, %768, !dbg !23 + %788 = extractelement <2 x i32> %778, i64 0, !dbg !23 + %789 = extractelement <2 x i32> %776, i64 0, !dbg !23 + %790 = icmp sle i32 %789, %788, !dbg !23 + %791 = or i1 %787, %790, !dbg !23 + %792 = and i1 %786, %791, !dbg !23 + %.not25 = xor i1 %792, %27, !dbg !23 + %793 = xor i32 %766, %765, !dbg !62 + %794 = xor i32 %768, %767, !dbg !62 + %795 = select i1 %.not24, i32 0, i32 %793, !dbg !63 + %796 = select i1 %.not25, i32 0, i32 %794, !dbg !63 + %797 = insertelement <2 x i32> poison, i32 %796, i64 0, !dbg !64 + %798 = insertelement <2 x i32> %797, i32 %795, i64 1, !dbg !64 + %799 = insertelement <2 x i32> poison, i32 %767, i64 0, !dbg !64 + %800 = insertelement <2 x i32> %799, i32 %765, i64 1, !dbg !64 + %801 = xor <2 x i32> %798, %800, !dbg !64 + %802 = 
insertelement <2 x i32> poison, i32 %768, i64 0, !dbg !64 + %803 = insertelement <2 x i32> %802, i32 %766, i64 1, !dbg !64 + %804 = xor <2 x i32> %798, %803, !dbg !64 + %805 = xor i32 %782, %781, !dbg !67 + %806 = xor i32 %788, %789, !dbg !67 + %807 = insertelement <2 x i1> poison, i1 %.not25, i64 0, !dbg !65 + %808 = insertelement <2 x i1> %807, i1 %.not24, i64 1, !dbg !65 + %809 = insertelement <2 x i32> poison, i32 %806, i64 0, !dbg !65 + %810 = insertelement <2 x i32> %809, i32 %805, i64 1, !dbg !65 + %811 = select <2 x i1> %808, <2 x i32> zeroinitializer, <2 x i32> %810, !dbg !65 + %812 = shufflevector <2 x i32> %811, <2 x i32> poison, <4 x i32> , !dbg !66 + %813 = xor <4 x i32> %773, %812, !dbg !66 + %814 = xor <2 x i32> %811, %776, !dbg !66 + %815 = xor <2 x i32> %811, %778, !dbg !66 + %816 = extractelement <2 x i32> %801, i64 1, !dbg !28 + %817 = mul nuw nsw i32 %816, %25, !dbg !27 + %818 = extractelement <2 x i32> %804, i64 1, !dbg !28 + %819 = mul nuw nsw i32 %818, %25, !dbg !27 + %820 = extractelement <2 x i32> %801, i64 0, !dbg !28 + %821 = mul nuw nsw i32 %820, %25, !dbg !27 + %822 = extractelement <2 x i32> %804, i64 0, !dbg !28 + %823 = mul nuw nsw i32 %822, %25, !dbg !27 + %824 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %817, i32 2, i32 31), !dbg !68 + %825 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %819, i32 2, i32 31), !dbg !68 + %826 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %821, i32 2, i32 31), !dbg !68 + %827 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %823, i32 2, i32 31), !dbg !68 + %828 = insertelement <4 x i32> poison, i32 %817, i64 0, !dbg !69 + %829 = insertelement <4 x i32> %828, i32 %819, i64 1, !dbg !69 + %830 = insertelement <4 x i32> %829, i32 %821, i64 2, !dbg !69 + %831 = insertelement <4 x i32> %830, i32 %823, i64 3, !dbg !69 + %832 = insertelement <4 x i32> poison, i32 %824, i64 0, !dbg !69 + %833 = insertelement <4 x i32> %832, i32 %825, i64 1, !dbg !69 + %834 = 
insertelement <4 x i32> %833, i32 %826, i64 2, !dbg !69 + %835 = insertelement <4 x i32> %834, i32 %827, i64 3, !dbg !69 + %836 = add <4 x i32> %831, %835, !dbg !69 + %837 = mul nuw nsw i32 %816, %.lobit, !dbg !28 + %838 = mul nuw nsw i32 %818, %.lobit, !dbg !28 + %839 = mul nuw nsw i32 %820, %.lobit, !dbg !28 + %840 = mul nuw nsw i32 %822, %.lobit, !dbg !28 + %841 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %837, i32 2, i32 31), !dbg !68 + %842 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %838, i32 2, i32 31), !dbg !68 + %843 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %839, i32 2, i32 31), !dbg !68 + %844 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %840, i32 2, i32 31), !dbg !68 + %845 = insertelement <4 x i32> poison, i32 %837, i64 0, !dbg !69 + %846 = insertelement <4 x i32> %845, i32 %838, i64 1, !dbg !69 + %847 = insertelement <4 x i32> %846, i32 %839, i64 2, !dbg !69 + %848 = insertelement <4 x i32> %847, i32 %840, i64 3, !dbg !69 + %849 = insertelement <4 x i32> poison, i32 %841, i64 0, !dbg !69 + %850 = insertelement <4 x i32> %849, i32 %842, i64 1, !dbg !69 + %851 = insertelement <4 x i32> %850, i32 %843, i64 2, !dbg !69 + %852 = insertelement <4 x i32> %851, i32 %844, i64 3, !dbg !69 + %853 = add <4 x i32> %848, %852, !dbg !69 + %854 = shufflevector <2 x i32> %815, <2 x i32> %814, <4 x i32> , !dbg !25 + %855 = mul nuw nsw <4 x i32> %854, %31, !dbg !25 + %856 = extractelement <4 x i32> %855, i64 0, !dbg !68 + %857 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %856, i32 2, i32 31), !dbg !68 + %858 = extractelement <4 x i32> %855, i64 1, !dbg !68 + %859 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %858, i32 2, i32 31), !dbg !68 + %860 = extractelement <4 x i32> %855, i64 2, !dbg !68 + %861 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %860, i32 2, i32 31), !dbg !68 + %862 = extractelement <4 x i32> %855, i64 3, !dbg !68 + %863 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %862, i32 2, i32 31), !dbg !68 + %864 = insertelement <4 x i32> poison, i32 %857, i64 0, !dbg !69 + %865 = insertelement <4 x i32> %864, i32 %859, i64 1, !dbg !69 + %866 = insertelement <4 x i32> %865, i32 %861, i64 2, !dbg !69 + %867 = insertelement <4 x i32> %866, i32 %863, i64 3, !dbg !69 + %868 = add <4 x i32> %855, %867, !dbg !69 + %869 = mul nuw nsw <4 x i32> %854, %33, !dbg !26 + %870 = extractelement <4 x i32> %869, i64 0, !dbg !68 + %871 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %870, i32 2, i32 31), !dbg !68 + %872 = extractelement <4 x i32> %869, i64 1, !dbg !68 + %873 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %872, i32 2, i32 31), !dbg !68 + %874 = extractelement <4 x i32> %869, i64 2, !dbg !68 + %875 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %874, i32 2, i32 31), !dbg !68 + %876 = extractelement <4 x i32> %869, i64 3, !dbg !68 + %877 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %876, i32 2, i32 31), !dbg !68 + %878 = insertelement <4 x i32> poison, i32 %871, i64 0, !dbg !69 + %879 = insertelement <4 x i32> %878, i32 %873, i64 1, !dbg !69 + %880 = insertelement <4 x i32> %879, i32 %875, i64 2, !dbg !69 + %881 = insertelement <4 x i32> %880, i32 %877, i64 3, !dbg !69 + %882 = add <4 x i32> %869, %881, !dbg !69 + %883 = icmp slt <4 x i32> %836, %853, !dbg !60 + %884 = icmp eq <4 x i32> %836, %853, !dbg !70 + %885 = icmp sgt <4 x i32> %868, %882, !dbg !71 + %886 = and <4 x i1> %884, %885, !dbg !72 + %887 = or <4 x i1> %883, %886, !dbg !73 + %888 = or <4 x i1> %883, %886, !dbg !73 + %889 = shufflevector <4 x i1> %888, <4 x i1> poison, <2 x i32> , !dbg !73 + %890 = or <4 x i1> %883, %886, !dbg !73 + %891 = shufflevector <4 x i1> %890, <4 x i1> poison, <2 x i32> , !dbg !73 + %foldExtExtBinop = xor <4 x i32> %836, %853, !dbg !62 + %foldExtExtBinop55 = xor <4 x i32> %836, %853, !dbg !62 + %foldExtExtBinop57 = xor <4 x i32> %836, %853, !dbg !62 + 
%foldExtExtBinop59 = xor <4 x i32> %836, %853, !dbg !62 + %892 = shufflevector <2 x i1> %889, <2 x i1> %891, <2 x i32> , !dbg !63 + %893 = shufflevector <4 x i32> %foldExtExtBinop57, <4 x i32> %foldExtExtBinop, <2 x i32> , !dbg !63 + %894 = select <2 x i1> %892, <2 x i32> %893, <2 x i32> zeroinitializer, !dbg !63 + %895 = shufflevector <2 x i1> %891, <2 x i1> %889, <2 x i32> , !dbg !63 + %896 = shufflevector <4 x i32> %foldExtExtBinop59, <4 x i32> %foldExtExtBinop55, <2 x i32> , !dbg !63 + %897 = select <2 x i1> %895, <2 x i32> %896, <2 x i32> zeroinitializer, !dbg !63 + %898 = shufflevector <2 x i32> %894, <2 x i32> %897, <2 x i32> , !dbg !64 + %899 = shufflevector <2 x i32> %801, <2 x i32> %804, <2 x i32> , !dbg !64 + %900 = xor <2 x i32> %898, %899, !dbg !64 + %901 = xor <2 x i32> %894, %801, !dbg !64 + %902 = xor <2 x i32> %897, %804, !dbg !64 + %903 = shufflevector <2 x i32> %897, <2 x i32> %894, <2 x i32> , !dbg !64 + %904 = shufflevector <2 x i32> %804, <2 x i32> %801, <2 x i32> , !dbg !64 + %905 = xor <2 x i32> %903, %904, !dbg !64 + %906 = xor <4 x i32> %868, %882, !dbg !67 + %907 = xor <4 x i32> %868, %882, !dbg !67 + %908 = shufflevector <4 x i32> %907, <4 x i32> poison, <2 x i32> , !dbg !67 + %909 = xor <4 x i32> %868, %882, !dbg !67 + %910 = shufflevector <4 x i32> %909, <4 x i32> poison, <2 x i32> , !dbg !67 + %911 = select <4 x i1> %887, <4 x i32> %906, <4 x i32> zeroinitializer, !dbg !65 + %912 = select <2 x i1> %889, <2 x i32> %908, <2 x i32> zeroinitializer, !dbg !65 + %913 = select <2 x i1> %891, <2 x i32> %910, <2 x i32> zeroinitializer, !dbg !65 + %914 = xor <4 x i32> %813, %911, !dbg !66 + %915 = shufflevector <2 x i32> %913, <2 x i32> %912, <2 x i32> , !dbg !66 + %916 = shufflevector <2 x i32> %815, <2 x i32> %814, <2 x i32> , !dbg !66 + %917 = xor <2 x i32> %915, %916, !dbg !66 + %918 = shufflevector <2 x i32> %912, <2 x i32> %913, <2 x i32> , !dbg !66 + %919 = shufflevector <2 x i32> %814, <2 x i32> %815, <2 x i32> , !dbg !66 + %920 = xor 
<2 x i32> %918, %919, !dbg !66 + %921 = shufflevector <2 x i32> %913, <2 x i32> %912, <2 x i32> , !dbg !66 + %922 = shufflevector <2 x i32> %815, <2 x i32> %814, <2 x i32> , !dbg !66 + %923 = xor <2 x i32> %921, %922, !dbg !66 + %924 = shufflevector <2 x i32> %912, <2 x i32> %913, <2 x i32> , !dbg !66 + %925 = shufflevector <2 x i32> %814, <2 x i32> %815, <2 x i32> , !dbg !66 + %926 = xor <2 x i32> %924, %925, !dbg !66 + %927 = shufflevector <2 x i32> %901, <2 x i32> %902, <4 x i32> , !dbg !27 + %928 = mul nuw nsw <4 x i32> %927, %35, !dbg !27 + %929 = extractelement <4 x i32> %928, i64 0, !dbg !68 + %930 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %929, i32 1, i32 31), !dbg !68 + %931 = extractelement <4 x i32> %928, i64 1, !dbg !68 + %932 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %931, i32 1, i32 31), !dbg !68 + %933 = extractelement <4 x i32> %928, i64 2, !dbg !68 + %934 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %933, i32 1, i32 31), !dbg !68 + %935 = extractelement <4 x i32> %928, i64 3, !dbg !68 + %936 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %935, i32 1, i32 31), !dbg !68 + %937 = insertelement <4 x i32> poison, i32 %930, i64 0, !dbg !69 + %938 = insertelement <4 x i32> %937, i32 %932, i64 1, !dbg !69 + %939 = insertelement <4 x i32> %938, i32 %934, i64 2, !dbg !69 + %940 = insertelement <4 x i32> %939, i32 %936, i64 3, !dbg !69 + %941 = add <4 x i32> %928, %940, !dbg !69 + %942 = mul nuw nsw <4 x i32> %927, %37, !dbg !28 + %943 = extractelement <4 x i32> %942, i64 0, !dbg !68 + %944 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %943, i32 1, i32 31), !dbg !68 + %945 = extractelement <4 x i32> %942, i64 1, !dbg !68 + %946 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %945, i32 1, i32 31), !dbg !68 + %947 = extractelement <4 x i32> %942, i64 2, !dbg !68 + %948 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %947, i32 1, i32 31), !dbg !68 + %949 = extractelement <4 x 
i32> %942, i64 3, !dbg !68 + %950 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %949, i32 1, i32 31), !dbg !68 + %951 = insertelement <4 x i32> poison, i32 %944, i64 0, !dbg !69 + %952 = insertelement <4 x i32> %951, i32 %946, i64 1, !dbg !69 + %953 = insertelement <4 x i32> %952, i32 %948, i64 2, !dbg !69 + %954 = insertelement <4 x i32> %953, i32 %950, i64 3, !dbg !69 + %955 = add <4 x i32> %942, %954, !dbg !69 + %956 = extractelement <2 x i32> %926, i64 1, !dbg !26 + %957 = mul nuw nsw i32 %956, %24, !dbg !25 + %958 = extractelement <2 x i32> %923, i64 1, !dbg !26 + %959 = mul nuw nsw i32 %958, %24, !dbg !25 + %960 = extractelement <2 x i32> %926, i64 0, !dbg !26 + %961 = mul nuw nsw i32 %960, %24, !dbg !25 + %962 = extractelement <2 x i32> %923, i64 0, !dbg !26 + %963 = mul nuw nsw i32 %962, %24, !dbg !25 + %964 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %957, i32 1, i32 31), !dbg !68 + %965 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %959, i32 1, i32 31), !dbg !68 + %966 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %961, i32 1, i32 31), !dbg !68 + %967 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %963, i32 1, i32 31), !dbg !68 + %968 = insertelement <2 x i32> poison, i32 %961, i64 0, !dbg !69 + %969 = insertelement <2 x i32> %968, i32 %959, i64 1, !dbg !69 + %970 = insertelement <2 x i32> poison, i32 %966, i64 0, !dbg !69 + %971 = insertelement <2 x i32> %970, i32 %965, i64 1, !dbg !69 + %972 = add <2 x i32> %969, %971, !dbg !69 + %973 = insertelement <4 x i32> poison, i32 %964, i64 0, !dbg !69 + %974 = insertelement <4 x i32> %973, i32 %965, i64 1, !dbg !69 + %975 = insertelement <4 x i32> %974, i32 %966, i64 2, !dbg !69 + %976 = insertelement <4 x i32> %975, i32 %967, i64 3, !dbg !69 + %977 = insertelement <4 x i32> poison, i32 %957, i64 0, !dbg !69 + %978 = insertelement <4 x i32> %977, i32 %959, i64 1, !dbg !69 + %979 = insertelement <4 x i32> %978, i32 %961, i64 2, !dbg !69 + %980 = 
insertelement <4 x i32> %979, i32 %963, i64 3, !dbg !69 + %981 = add <4 x i32> %976, %980, !dbg !69 + %982 = insertelement <2 x i32> poison, i32 %963, i64 0, !dbg !69 + %983 = insertelement <2 x i32> %982, i32 %957, i64 1, !dbg !69 + %984 = insertelement <2 x i32> poison, i32 %967, i64 0, !dbg !69 + %985 = insertelement <2 x i32> %984, i32 %964, i64 1, !dbg !69 + %986 = add <2 x i32> %983, %985, !dbg !69 + %987 = mul nuw nsw i32 %956, %22, !dbg !26 + %988 = mul nuw nsw i32 %958, %22, !dbg !26 + %989 = mul nuw nsw i32 %960, %22, !dbg !26 + %990 = mul nuw nsw i32 %962, %22, !dbg !26 + %991 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %987, i32 1, i32 31), !dbg !68 + %992 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %988, i32 1, i32 31), !dbg !68 + %993 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %989, i32 1, i32 31), !dbg !68 + %994 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %990, i32 1, i32 31), !dbg !68 + %995 = insertelement <2 x i32> poison, i32 %993, i64 0, !dbg !69 + %996 = insertelement <2 x i32> %995, i32 %992, i64 1, !dbg !69 + %997 = insertelement <2 x i32> poison, i32 %989, i64 0, !dbg !69 + %998 = insertelement <2 x i32> %997, i32 %988, i64 1, !dbg !69 + %999 = add <2 x i32> %996, %998, !dbg !69 + %1000 = insertelement <4 x i32> poison, i32 %991, i64 0, !dbg !69 + %1001 = insertelement <4 x i32> %1000, i32 %992, i64 1, !dbg !69 + %1002 = insertelement <4 x i32> %1001, i32 %993, i64 2, !dbg !69 + %1003 = insertelement <4 x i32> %1002, i32 %994, i64 3, !dbg !69 + %1004 = insertelement <4 x i32> poison, i32 %987, i64 0, !dbg !69 + %1005 = insertelement <4 x i32> %1004, i32 %988, i64 1, !dbg !69 + %1006 = insertelement <4 x i32> %1005, i32 %989, i64 2, !dbg !69 + %1007 = insertelement <4 x i32> %1006, i32 %990, i64 3, !dbg !69 + %1008 = add <4 x i32> %1003, %1007, !dbg !69 + %1009 = insertelement <2 x i32> poison, i32 %994, i64 0, !dbg !69 + %1010 = insertelement <2 x i32> %1009, i32 %991, i64 1, !dbg !69 
+ %1011 = insertelement <2 x i32> poison, i32 %990, i64 0, !dbg !69 + %1012 = insertelement <2 x i32> %1011, i32 %987, i64 1, !dbg !69 + %1013 = add <2 x i32> %1010, %1012, !dbg !69 + %1014 = icmp slt <4 x i32> %941, %955, !dbg !60 + %1015 = icmp slt <4 x i32> %941, %955, !dbg !60 + %1016 = shufflevector <4 x i1> %1015, <4 x i1> poison, <2 x i32> , !dbg !60 + %1017 = icmp slt <4 x i32> %941, %955, !dbg !60 + %1018 = shufflevector <4 x i1> %1017, <4 x i1> poison, <2 x i32> , !dbg !60 + %1019 = icmp eq <4 x i32> %941, %955, !dbg !70 + %1020 = icmp sgt <4 x i32> %981, %1008, !dbg !71 + %1021 = and <4 x i1> %1019, %1020, !dbg !72 + %1022 = and <4 x i1> %1019, %1020, !dbg !72 + %1023 = shufflevector <4 x i1> %1022, <4 x i1> poison, <2 x i32> , !dbg !72 + %1024 = and <4 x i1> %1019, %1020, !dbg !72 + %1025 = shufflevector <4 x i1> %1024, <4 x i1> poison, <2 x i32> , !dbg !72 + %1026 = or <4 x i1> %1014, %1021, !dbg !73 + %1027 = or <2 x i1> %1016, %1023, !dbg !73 + %1028 = or <2 x i1> %1018, %1025, !dbg !73 + %1029 = shufflevector <2 x i1> %1018, <2 x i1> %1016, <2 x i32> , !dbg !73 + %1030 = shufflevector <2 x i1> %1025, <2 x i1> %1023, <2 x i32> , !dbg !73 + %1031 = or <2 x i1> %1029, %1030, !dbg !73 + %1032 = shufflevector <2 x i1> %1016, <2 x i1> %1018, <2 x i32> , !dbg !73 + %1033 = shufflevector <2 x i1> %1023, <2 x i1> %1025, <2 x i32> , !dbg !73 + %1034 = or <2 x i1> %1032, %1033, !dbg !73 + %1035 = xor <4 x i32> %941, %955, !dbg !62 + %1036 = shufflevector <4 x i32> %1035, <4 x i32> poison, <2 x i32> , !dbg !62 + %1037 = xor <4 x i32> %941, %955, !dbg !62 + %1038 = shufflevector <4 x i32> %1037, <4 x i32> poison, <2 x i32> , !dbg !62 + %1039 = shufflevector <2 x i1> %1034, <2 x i1> %1031, <2 x i32> , !dbg !63 + %1040 = shufflevector <2 x i32> %1036, <2 x i32> %1038, <2 x i32> , !dbg !63 + %1041 = select <2 x i1> %1039, <2 x i32> %1040, <2 x i32> zeroinitializer, !dbg !63 + %1042 = select <2 x i1> %1034, <2 x i32> %1036, <2 x i32> zeroinitializer, !dbg !63 + 
%1043 = select <2 x i1> %1031, <2 x i32> %1038, <2 x i32> zeroinitializer, !dbg !63 + %1044 = shufflevector <2 x i1> %1031, <2 x i1> %1034, <2 x i32> , !dbg !63 + %1045 = shufflevector <2 x i32> %1038, <2 x i32> %1036, <2 x i32> , !dbg !63 + %1046 = select <2 x i1> %1044, <2 x i32> %1045, <2 x i32> zeroinitializer, !dbg !63 + %1047 = xor <2 x i32> %1041, %900, !dbg !64 + %1048 = xor <2 x i32> %1042, %901, !dbg !64 + %1049 = xor <2 x i32> %1043, %902, !dbg !64 + %1050 = xor <2 x i32> %1046, %905, !dbg !64 + %1051 = shufflevector <2 x i32> %1013, <2 x i32> %999, <4 x i32> , !dbg !67 + %1052 = shufflevector <2 x i32> %986, <2 x i32> %972, <4 x i32> , !dbg !67 + %1053 = xor <4 x i32> %1051, %1052, !dbg !67 + %1054 = xor <2 x i32> %999, %972, !dbg !67 + %1055 = xor <2 x i32> %1013, %986, !dbg !67 + %1056 = select <4 x i1> %1026, <4 x i32> %1053, <4 x i32> zeroinitializer, !dbg !65 + %1057 = select <2 x i1> %1027, <2 x i32> %1054, <2 x i32> zeroinitializer, !dbg !65 + %1058 = select <2 x i1> %1028, <2 x i32> %1055, <2 x i32> zeroinitializer, !dbg !65 + %1059 = shufflevector <2 x i1> %1031, <2 x i1> %1034, <2 x i32> , !dbg !65 + %1060 = shufflevector <2 x i32> %1054, <2 x i32> %1055, <2 x i32> , !dbg !65 + %1061 = select <2 x i1> %1059, <2 x i32> %1060, <2 x i32> zeroinitializer, !dbg !65 + %1062 = shufflevector <2 x i1> %1031, <2 x i1> %1034, <2 x i32> , !dbg !65 + %1063 = shufflevector <2 x i32> %1055, <2 x i32> %1054, <2 x i32> , !dbg !65 + %1064 = select <2 x i1> %1062, <2 x i32> %1063, <2 x i32> zeroinitializer, !dbg !65 + %1065 = shufflevector <2 x i32> %1055, <2 x i32> %1054, <2 x i32> , !dbg !65 + %1066 = select <2 x i1> %1031, <2 x i32> %1065, <2 x i32> zeroinitializer, !dbg !65 + %1067 = shufflevector <2 x i32> %1054, <2 x i32> %1055, <2 x i32> , !dbg !65 + %1068 = select <2 x i1> %1034, <2 x i32> %1067, <2 x i32> zeroinitializer, !dbg !65 + %1069 = xor <4 x i32> %914, %1056, !dbg !66 + %1070 = xor <2 x i32> %1061, %917, !dbg !66 + %1071 = xor <2 x i32> %1064, 
%920, !dbg !66 + %1072 = xor <2 x i32> %1066, %923, !dbg !66 + %1073 = xor <2 x i32> %1068, %926, !dbg !66 + %1074 = extractelement <2 x i32> %1048, i64 0, !dbg !60 + %1075 = extractelement <2 x i32> %1048, i64 1, !dbg !60 + %1076 = icmp slt i32 %1075, %1074, !dbg !60 + %1077 = extractelement <2 x i32> %1049, i64 0, !dbg !60 + %1078 = extractelement <2 x i32> %1049, i64 1, !dbg !60 + %1079 = icmp slt i32 %1078, %1077, !dbg !60 + %1080 = icmp eq i32 %1075, %1074, !dbg !70 + %1081 = icmp eq i32 %1078, %1077, !dbg !70 + %shift61 = shufflevector <2 x i32> %1073, <2 x i32> poison, <2 x i32> , !dbg !71 + %1082 = icmp sgt <2 x i32> %shift61, %1073, !dbg !71 + %1083 = extractelement <2 x i1> %1082, i64 0, !dbg !71 + %shift62 = shufflevector <2 x i32> %1072, <2 x i32> poison, <2 x i32> , !dbg !71 + %1084 = icmp sgt <2 x i32> %shift62, %1072, !dbg !71 + %1085 = extractelement <2 x i1> %1084, i64 0, !dbg !71 + %1086 = and i1 %1080, %1083, !dbg !72 + %1087 = and i1 %1081, %1085, !dbg !72 + %1088 = insertelement <2 x i1> poison, i1 %1076, i64 0, !dbg !73 + %1089 = insertelement <2 x i1> %1088, i1 %1079, i64 1, !dbg !73 + %1090 = insertelement <2 x i1> poison, i1 %1086, i64 0, !dbg !73 + %1091 = insertelement <2 x i1> %1090, i1 %1087, i64 1, !dbg !73 + %1092 = or <2 x i1> %1089, %1091, !dbg !73 + %1093 = shufflevector <2 x i32> %1048, <2 x i32> %1049, <2 x i32> , !dbg !62 + %1094 = shufflevector <2 x i32> %1048, <2 x i32> %1049, <2 x i32> , !dbg !62 + %1095 = xor <2 x i32> %1093, %1094, !dbg !62 + %1096 = select <2 x i1> %1092, <2 x i32> %1095, <2 x i32> zeroinitializer, !dbg !63 + %1097 = xor <2 x i32> %1096, %1047, !dbg !64 + %1098 = shufflevector <2 x i32> %1096, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !64 + %1099 = xor <2 x i32> %1098, %1048, !dbg !64 + %1100 = shufflevector <2 x i32> %1096, <2 x i32> poison, <2 x i32> , !dbg !64 + %1101 = shufflevector <2 x i32> %1096, <2 x i32> poison, <2 x i32> , !dbg !64 + %1102 = xor <2 x i32> %1101, %1049, !dbg !64 + %1103 = 
xor <2 x i32> %1100, %1050, !dbg !64 + %1104 = xor <2 x i32> %1071, %1070, !dbg !67 + %1105 = shufflevector <2 x i1> %1092, <2 x i1> poison, <2 x i32> , !dbg !65 + %1106 = select <2 x i1> %1105, <2 x i32> %1104, <2 x i32> zeroinitializer, !dbg !65 + %1107 = shufflevector <2 x i32> %1106, <2 x i32> poison, <4 x i32> , !dbg !66 + %1108 = xor <4 x i32> %1069, %1107, !dbg !66 + %1109 = shufflevector <2 x i32> %1106, <2 x i32> poison, <2 x i32> , !dbg !66 + %1110 = xor <2 x i32> %1057, %1109, !dbg !66 + %1111 = shufflevector <2 x i32> %1106, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !66 + %1112 = xor <2 x i32> %1111, %1072, !dbg !66 + %1113 = shufflevector <2 x i32> %1106, <2 x i32> poison, <2 x i32> , !dbg !66 + %1114 = xor <2 x i32> %1113, %1073, !dbg !66 + %1115 = icmp slt <2 x i32> %1099, %1102, !dbg !60 + %1116 = icmp eq <2 x i32> %1097, %1103, !dbg !70 + %1117 = icmp sgt <2 x i32> %1114, %1112, !dbg !71 + %1118 = and <2 x i1> %1116, %1117, !dbg !72 + %1119 = or <2 x i1> %1115, %1118, !dbg !73 + %1120 = xor <2 x i32> %1110, %1058, !dbg !67 + %1121 = xor <2 x i32> %1120, %814, !dbg !67 + %1122 = xor <2 x i32> %1121, %912, !dbg !67 + %1123 = xor <2 x i32> %1122, %815, !dbg !67 + %1124 = xor <2 x i32> %1123, %913, !dbg !67 + %1125 = xor <2 x i32> %1124, %1106, !dbg !67 + %1126 = select <2 x i1> %1119, <2 x i32> %1125, <2 x i32> zeroinitializer, !dbg !65 + %1127 = shufflevector <2 x i32> %1126, <2 x i32> poison, <4 x i32> , !dbg !65 + %1128 = xor <4 x i32> %1108, %1127, !dbg !66 + %1129 = xor <4 x i32> %1128, %739, !dbg !66 + %1130 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %539, i32 2, i32 31), !dbg !58 + %1131 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 2, i32 31), !dbg !58 + %1132 = insertelement <2 x i32> poison, i32 %1130, i64 0, !dbg !58 + %1133 = insertelement <2 x i32> %1132, i32 %1131, i64 1, !dbg !58 + %1134 = bitcast <2 x i32> %1133 to i64, !dbg !58 + %1135 = add i64 %538, %1134, !dbg !56 + %extelt.offset = lshr 
i64 %1135, 32, !dbg !58 + %1136 = trunc nuw i64 %extelt.offset to i32, !dbg !58 + %1137 = trunc i64 %1135 to i32, !dbg !58 + %1138 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %1137, i32 1, i32 31), !dbg !58 + %1139 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %1136, i32 1, i32 31), !dbg !58 + %1140 = insertelement <2 x i32> poison, i32 %1138, i64 0, !dbg !58 + %1141 = insertelement <2 x i32> %1140, i32 %1139, i64 1, !dbg !58 + %1142 = bitcast <2 x i32> %1141 to i64, !dbg !58 + %1143 = add i64 %1135, %1142, !dbg !56 + %1144 = insertelement <1 x i64> poison, i64 %1143, i64 0, !dbg !34 + store <1 x i64> %1144, ptr addrspace(3) %49, align 8, !dbg !34 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !34 + %1145 = load i64, ptr addrspace(3) %51, align 8, !dbg !34 + %1146 = select <4 x i1> %47, <4 x i1> %540, <4 x i1> zeroinitializer, !dbg !74 + %1147 = bitcast <4 x i1> %1146 to i4, !dbg !75 + %1148 = tail call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 %1147), !dbg !75 + %1149 = zext nneg i4 %1148 to i64, !dbg !75 + %1150 = zext nneg i4 %1148 to i32, !dbg !77 + %1151 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %1150, i32 2, i32 31), !dbg !77 + %1152 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 2, i32 31), !dbg !77 + %1153 = insertelement <2 x i32> poison, i32 %1151, i64 0, !dbg !77 + %1154 = insertelement <2 x i32> %1153, i32 %1152, i64 1, !dbg !77 + %1155 = bitcast <2 x i32> %1154 to i64, !dbg !77 + %1156 = add i64 %1149, %1155, !dbg !75 + %extelt.offset34 = lshr i64 %1156, 32, !dbg !77 + %1157 = trunc nuw i64 %extelt.offset34 to i32, !dbg !77 + %1158 = trunc i64 %1156 to i32, !dbg !77 + %1159 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %1158, i32 1, i32 31), !dbg !77 + %1160 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %1157, i32 1, i32 31), !dbg !77 + %1161 = insertelement <2 x i32> poison, i32 %1159, i64 0, !dbg !77 + %1162 = insertelement <2 x i32> %1161, i32 %1160, 
i64 1, !dbg !77 + %1163 = bitcast <2 x i32> %1162 to i64, !dbg !77 + %1164 = add i64 %1156, %1163, !dbg !75 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !78 + %1165 = insertelement <1 x i64> poison, i64 %1164, i64 0, !dbg !78 + store <1 x i64> %1165, ptr addrspace(3) %49, align 8, !dbg !78 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !78 + %1166 = load i64, ptr addrspace(3) %51, align 8, !dbg !78 + %1167 = trunc i64 %1143 to i32, !dbg !34 + %1168 = insertelement <4 x i32> poison, i32 %55, i64 0, !dbg !79 + %1169 = insertelement <4 x i32> %1168, i32 %54, i64 1, !dbg !79 + %1170 = insertelement <4 x i32> %1169, i32 %53, i64 2, !dbg !79 + %1171 = insertelement <4 x i32> %1170, i32 %52, i64 3, !dbg !79 + %1172 = insertelement <4 x i32> poison, i32 %1167, i64 0, !dbg !79 + %1173 = shufflevector <4 x i32> %1172, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !79 + %1174 = icmp slt <4 x i32> %1171, %1173, !dbg !79 + %1175 = select <4 x i1> %1174, <4 x i32> %534, <4 x i32> splat (i32 16), !dbg !80 + %1176 = add <4 x i32> %1175, splat (i32 17), !dbg !81 + %1177 = icmp slt <4 x i32> %1175, zeroinitializer, !dbg !82 + %1178 = select <4 x i1> %1177, <4 x i32> %1176, <4 x i32> %1175, !dbg !83 + %1179 = icmp ugt <4 x i32> %1178, splat (i32 16), !dbg !84 + %shift63 = shufflevector <4 x i1> %1179, <4 x i1> poison, <4 x i32> , !dbg !85 + %foldExtExtBinop64 = or <4 x i1> %shift63, %1179, !dbg !85 + %shift66 = shufflevector <4 x i1> %foldExtExtBinop64, <4 x i1> poison, <4 x i32> , !dbg !85 + %foldExtExtBinop67 = or <4 x i1> %1179, %shift66, !dbg !85 + %shift69 = shufflevector <4 x i1> %foldExtExtBinop67, <4 x i1> poison, <4 x i32> , !dbg !85 + %foldExtExtBinop70 = or <4 x i1> %1179, %shift69, !dbg !85 + %1180 = extractelement <4 x i1> %foldExtExtBinop70, i64 0, !dbg !85 + %1181 = and i1 %19, %1180, !dbg !85 + br i1 %1181, label %1182, label %1183, !dbg !85 + +1182: ; preds = %11 + tail call void @__assertfail(ptr nonnull 
@assertMessage_0, ptr nonnull @assertFile_0, i32 71, ptr nonnull @assertFunc_0, i64 1), !dbg !85 + unreachable, !dbg !85 + +1183: ; preds = %11 + %1184 = trunc i64 %1164 to i32, !dbg !78 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !85 + %1185 = insertelement <4 x i32> poison, i32 %52, i64 0, !dbg !86 + %1186 = insertelement <4 x i32> %1185, i32 %53, i64 1, !dbg !86 + %1187 = insertelement <4 x i32> %1186, i32 %54, i64 2, !dbg !86 + %1188 = insertelement <4 x i32> %1187, i32 %55, i64 3, !dbg !86 + %1189 = insertelement <4 x i32> poison, i32 %1184, i64 0, !dbg !86 + %1190 = shufflevector <4 x i32> %1189, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !86 + %1191 = icmp slt <4 x i32> %1188, %1190, !dbg !86 + %1192 = select <4 x i1> %1191, <4 x i32> %1129, <4 x i32> splat (i32 16), !dbg !87 + %1193 = add <4 x i32> %1192, splat (i32 17), !dbg !88 + %1194 = icmp slt <4 x i32> %1192, zeroinitializer, !dbg !89 + %1195 = select <4 x i1> %1194, <4 x i32> %1193, <4 x i32> %1192, !dbg !90 + %1196 = icmp ugt <4 x i32> %1195, splat (i32 16), !dbg !91 + %1197 = bitcast <4 x i1> %1196 to i4, !dbg !92 + %1198 = icmp ne i4 %1197, 0, !dbg !92 + %1199 = and i1 %19, %1198, !dbg !92 + br i1 %1199, label %1200, label %1201, !dbg !92 + +1200: ; preds = %1183 + tail call void @__assertfail(ptr nonnull @assertMessage_1, ptr nonnull @assertFile_1, i32 80, ptr nonnull @assertFunc_1, i64 1), !dbg !92 + unreachable, !dbg !92 + +1201: ; preds = %1183 + %1202 = trunc i64 %1166 to i32, !dbg !78 + %1203 = trunc i64 %1145 to i32, !dbg !34 + %1204 = or disjoint i32 %13, %17, !dbg !13 + %1205 = icmp slt i32 %1204, 128, !dbg !14 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !92 + %1206 = sext i32 %1204 to i64, !dbg !93 + %1207 = getelementptr i32, ptr addrspace(1) %1, i64 %1206, !dbg !93 + %1208 = and i32 %14, 96, !dbg !94 + %1209 = icmp eq i32 %1208, 0, !dbg !94 + %1210 = and i1 %1209, %1205, !dbg !94 + tail call void asm sideeffect "@$2 
st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %1203, ptr addrspace(1) %1207, i1 %1210) #6, !dbg !94 + %1211 = getelementptr i32, ptr addrspace(1) %2, i64 %1206, !dbg !95 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %1202, ptr addrspace(1) %1211, i1 %1210) #6, !dbg !96 + %1212 = getelementptr i32, ptr addrspace(1) %3, i64 %58, !dbg !97 + %1213 = extractelement <4 x i32> %534, i64 0, !dbg !98 + %1214 = extractelement <4 x i32> %534, i64 1, !dbg !98 + %1215 = extractelement <4 x i32> %534, i64 2, !dbg !98 + %1216 = extractelement <4 x i32> %534, i64 3, !dbg !98 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %1216, i32 %1215, i32 %1214, i32 %1213, ptr addrspace(1) %1212, i1 %19) #6, !dbg !98 + %1217 = mul i32 %18, 17, !dbg !99 + %1218 = extractelement <4 x i32> %1178, i64 3, !dbg !100 + %1219 = add i32 %1218, %1217, !dbg !100 + %1220 = extractelement <4 x i32> %1178, i64 2, !dbg !100 + %1221 = add i32 %1220, %1217, !dbg !100 + %1222 = extractelement <4 x i32> %1178, i64 1, !dbg !100 + %1223 = add i32 %1222, %1217, !dbg !100 + %1224 = extractelement <4 x i32> %1178, i64 0, !dbg !100 + %1225 = add i32 %1224, %1217, !dbg !100 + %1226 = sext i32 %1219 to i64, !dbg !101 + %1227 = getelementptr i32, ptr addrspace(1) %4, i64 %1226, !dbg !101 + %1228 = sext i32 %1221 to i64, !dbg !101 + %1229 = getelementptr i32, ptr addrspace(1) %4, i64 %1228, !dbg !101 + %1230 = sext i32 %1223 to i64, !dbg !101 + %1231 = getelementptr i32, ptr addrspace(1) %4, i64 %1230, !dbg !101 + %1232 = sext i32 %1225 to i64, !dbg !101 + %1233 = getelementptr i32, ptr addrspace(1) %4, i64 %1232, !dbg !101 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !102 + %1234 = ptrtoint ptr addrspace(1) %1227 to i64, !dbg !102 + %1235 = ptrtoint ptr addrspace(1) %1229 to i64, !dbg !102 + %1236 = ptrtoint ptr addrspace(1) %1231 to i64, !dbg !102 + %1237 = ptrtoint ptr addrspace(1) %1233 to i64, 
!dbg !102 + %1238 = and i32 %14, 48, !dbg !102 + %1239 = shl nuw nsw i32 %1238, 6, !dbg !102 + %1240 = shl nuw nsw i32 %14, 3, !dbg !102 + %1241 = and i32 %1240, 120, !dbg !102 + %1242 = lshr exact i32 %1238, 1, !dbg !102 + %1243 = shl nuw nsw i32 %14, 1, !dbg !102 + %1244 = and i32 %1243, 128, !dbg !102 + %1245 = or disjoint i32 %1239, %1241, !dbg !102 + %1246 = xor i32 %1245, %1242, !dbg !102 + %1247 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %1244, !dbg !102 + %1248 = getelementptr inbounds nuw i8, ptr addrspace(3) %1247, i32 %1246, !dbg !102 + %1249 = insertelement <1 x i64> poison, i64 %1234, i64 0, !dbg !102 + store <1 x i64> %1249, ptr addrspace(3) %1248, align 8, !dbg !102 + %1250 = getelementptr inbounds nuw i8, ptr addrspace(3) %1248, i32 256, !dbg !102 + %1251 = insertelement <1 x i64> poison, i64 %1235, i64 0, !dbg !102 + store <1 x i64> %1251, ptr addrspace(3) %1250, align 8, !dbg !102 + %1252 = getelementptr inbounds nuw i8, ptr addrspace(3) %1248, i32 512, !dbg !102 + %1253 = insertelement <1 x i64> poison, i64 %1236, i64 0, !dbg !102 + store <1 x i64> %1253, ptr addrspace(3) %1252, align 8, !dbg !102 + %1254 = getelementptr inbounds nuw i8, ptr addrspace(3) %1248, i32 768, !dbg !102 + %1255 = insertelement <1 x i64> poison, i64 %1237, i64 0, !dbg !102 + store <1 x i64> %1255, ptr addrspace(3) %1254, align 8, !dbg !102 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !102 + %1256 = and i32 %14, 12, !dbg !102 + %1257 = shl nuw nsw i32 %1256, 8, !dbg !102 + %1258 = shl nuw nsw i32 %20, 5, !dbg !102 + %1259 = and i32 %1240, 896, !dbg !102 + %1260 = shl nuw nsw i32 %1256, 1, !dbg !102 + %1261 = or disjoint i32 %1257, %1258, !dbg !102 + %1262 = or disjoint i32 %1261, %1259, !dbg !102 + %1263 = or disjoint i32 %1262, %1260, !dbg !102 + %1264 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %1263, !dbg !102 + %1265 = load i64, ptr addrspace(3) %1264, align 8, !dbg !102 + %1266 = xor i32 %1263, 
8, !dbg !102 + %1267 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %1266, !dbg !102 + %1268 = load i64, ptr addrspace(3) %1267, align 8, !dbg !102 + %1269 = xor i32 %1263, 16, !dbg !102 + %1270 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %1269, !dbg !102 + %1271 = load i64, ptr addrspace(3) %1270, align 8, !dbg !102 + %1272 = xor i32 %1263, 24, !dbg !102 + %1273 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %1272, !dbg !102 + %1274 = load i64, ptr addrspace(3) %1273, align 8, !dbg !102 + %1275 = inttoptr i64 %1265 to ptr addrspace(1), !dbg !102 + %1276 = inttoptr i64 %1268 to ptr addrspace(1), !dbg !102 + %1277 = inttoptr i64 %1271 to ptr addrspace(1), !dbg !102 + %1278 = inttoptr i64 %1274 to ptr addrspace(1), !dbg !102 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1275, i1 %1205) #6, !dbg !102 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1276, i1 %1205) #6, !dbg !102 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1277, i1 %1205) #6, !dbg !102 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1278, i1 %1205) #6, !dbg !102 + %1279 = getelementptr i32, ptr addrspace(1) %5, i64 %58, !dbg !103 + %1280 = extractelement <4 x i32> %1129, i64 0, !dbg !104 + %1281 = extractelement <4 x i32> %1129, i64 1, !dbg !104 + %1282 = extractelement <4 x i32> %1129, i64 2, !dbg !104 + %1283 = extractelement <4 x i32> %1129, i64 3, !dbg !104 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %1280, i32 %1281, i32 %1282, i32 %1283, ptr addrspace(1) %1279, i1 %19) #6, !dbg !104 + %1284 = extractelement <4 x i32> %1195, i64 0, !dbg !105 + %1285 = add i32 %1284, %1217, !dbg !105 + %1286 = extractelement <4 x i32> %1195, i64 1, !dbg !105 
+ %1287 = add i32 %1286, %1217, !dbg !105 + %1288 = extractelement <4 x i32> %1195, i64 2, !dbg !105 + %1289 = add i32 %1288, %1217, !dbg !105 + %1290 = extractelement <4 x i32> %1195, i64 3, !dbg !105 + %1291 = add i32 %1290, %1217, !dbg !105 + %1292 = sext i32 %1285 to i64, !dbg !106 + %1293 = getelementptr i32, ptr addrspace(1) %6, i64 %1292, !dbg !106 + %1294 = sext i32 %1287 to i64, !dbg !106 + %1295 = getelementptr i32, ptr addrspace(1) %6, i64 %1294, !dbg !106 + %1296 = sext i32 %1289 to i64, !dbg !106 + %1297 = getelementptr i32, ptr addrspace(1) %6, i64 %1296, !dbg !106 + %1298 = sext i32 %1291 to i64, !dbg !106 + %1299 = getelementptr i32, ptr addrspace(1) %6, i64 %1298, !dbg !106 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !107 + %1300 = ptrtoint ptr addrspace(1) %1293 to i64, !dbg !107 + %1301 = ptrtoint ptr addrspace(1) %1295 to i64, !dbg !107 + %1302 = ptrtoint ptr addrspace(1) %1297 to i64, !dbg !107 + %1303 = ptrtoint ptr addrspace(1) %1299 to i64, !dbg !107 + %1304 = insertelement <1 x i64> poison, i64 %1300, i64 0, !dbg !107 + store <1 x i64> %1304, ptr addrspace(3) %1248, align 8, !dbg !107 + %1305 = insertelement <1 x i64> poison, i64 %1301, i64 0, !dbg !107 + store <1 x i64> %1305, ptr addrspace(3) %1250, align 8, !dbg !107 + %1306 = insertelement <1 x i64> poison, i64 %1302, i64 0, !dbg !107 + store <1 x i64> %1306, ptr addrspace(3) %1252, align 8, !dbg !107 + %1307 = insertelement <1 x i64> poison, i64 %1303, i64 0, !dbg !107 + store <1 x i64> %1307, ptr addrspace(3) %1254, align 8, !dbg !107 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !107 + %1308 = load i64, ptr addrspace(3) %1264, align 8, !dbg !107 + %1309 = load i64, ptr addrspace(3) %1267, align 8, !dbg !107 + %1310 = load i64, ptr addrspace(3) %1270, align 8, !dbg !107 + %1311 = load i64, ptr addrspace(3) %1273, align 8, !dbg !107 + %1312 = inttoptr i64 %1308 to ptr addrspace(1), !dbg !107 + %1313 = inttoptr i64 %1309 to ptr 
addrspace(1), !dbg !107 + %1314 = inttoptr i64 %1310 to ptr addrspace(1), !dbg !107 + %1315 = inttoptr i64 %1311 to ptr addrspace(1), !dbg !107 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1312, i1 %1205) #6, !dbg !107 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1313, i1 %1205) #6, !dbg !107 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1314, i1 %1205) #6, !dbg !107 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %1315, i1 %1205) #6, !dbg !107 + ret void, !dbg !108 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #2 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #2 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #3 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #4 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i4 @llvm.ctpop.i4(i4) #5 + +attributes #0 = { noreturn } +attributes #1 = { "nvvm.reqntid"="128" } +attributes #2 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #3 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #4 = { convergent nocallback nounwind } +attributes #5 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #6 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 
= distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = !DISubprogram(name: "__assertfail", linkageName: "__assertfail", scope: !6, file: !6, type: !7, spFlags: DISPFlagOptimized) +!6 = !DIFile(filename: "", directory: "") +!7 = !DISubroutineType(cc: DW_CC_normal, types: !8) +!8 = !{} +!9 = distinct !DISubprogram(name: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", linkageName: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", scope: !1, file: !1, line: 18, type: !7, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!10 = !DILocation(line: 24, column: 28, scope: !9) +!11 = !DILocation(line: 24, column: 33, scope: !9) +!12 = !DILocation(line: 25, column: 44, scope: !9) +!13 = !DILocation(line: 25, column: 23, scope: !9) +!14 = !DILocation(line: 26, column: 21, scope: !9) +!15 = !DILocation(line: 27, column: 38, scope: !9) +!16 = !DILocation(line: 34, column: 40, scope: !9) +!17 = !DILocation(line: 627, column: 44, scope: !18, inlinedAt: !20) +!18 = distinct !DILexicalBlockFile(scope: !9, file: !19, discriminator: 0) +!19 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!20 = !DILocation(line: 46, column: 71, scope: !9) +!21 = !DILocation(line: 537, column: 21, scope: !18, inlinedAt: !20) +!22 = !DILocation(line: 599, column: 28, scope: !18, inlinedAt: !20) +!23 = !DILocation(line: 599, column: 28, scope: !18, inlinedAt: !24) 
+!24 = !DILocation(line: 51, column: 71, scope: !9) +!25 = !DILocation(line: 548, column: 23, scope: !18, inlinedAt: !24) +!26 = !DILocation(line: 551, column: 23, scope: !18, inlinedAt: !24) +!27 = !DILocation(line: 538, column: 40, scope: !18, inlinedAt: !24) +!28 = !DILocation(line: 539, column: 41, scope: !18, inlinedAt: !24) +!29 = !DILocation(line: 538, column: 40, scope: !18, inlinedAt: !20) +!30 = !DILocation(line: 539, column: 41, scope: !18, inlinedAt: !20) +!31 = !DILocation(line: 551, column: 23, scope: !18, inlinedAt: !20) +!32 = !DILocation(line: 548, column: 23, scope: !18, inlinedAt: !20) +!33 = !DILocation(line: 54, column: 35, scope: !9) +!34 = !DILocation(line: 60, column: 21, scope: !9) +!35 = !DILocation(line: 34, column: 37, scope: !9) +!36 = !DILocation(line: 34, column: 30, scope: !9) +!37 = !DILocation(line: 34, column: 45, scope: !9) +!38 = !DILocation(line: 601, column: 48, scope: !18, inlinedAt: !20) +!39 = !DILocation(line: 39, column: 18, scope: !9) +!40 = !DILocation(line: 574, column: 22, scope: !18, inlinedAt: !20) +!41 = !DILocation(line: 0, scope: !9) +!42 = !DILocation(line: 599, column: 19, scope: !18, inlinedAt: !20) +!43 = !DILocation(line: 600, column: 38, scope: !18, inlinedAt: !20) +!44 = !DILocation(line: 600, column: 46, scope: !18, inlinedAt: !20) +!45 = !DILocation(line: 600, column: 15, scope: !18, inlinedAt: !20) +!46 = !DILocation(line: 601, column: 59, scope: !18, inlinedAt: !20) +!47 = !DILocation(line: 601, column: 22, scope: !18, inlinedAt: !20) +!48 = !DILocation(line: 291, column: 36, scope: !49, inlinedAt: !20) +!49 = distinct !DILexicalBlockFile(scope: !9, file: !50, discriminator: 0) +!50 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!51 = !DILocation(line: 261, column: 15, scope: !49, inlinedAt: !20) +!52 = !DILocation(line: 591, column: 21, scope: !18, inlinedAt: !20) +!53 = !DILocation(line: 594, column: 40, scope: !18, inlinedAt: !20) 
+!54 = !DILocation(line: 594, column: 29, scope: !18, inlinedAt: !20) +!55 = !DILocation(line: 594, column: 23, scope: !18, inlinedAt: !20) +!56 = !DILocation(line: 261, column: 15, scope: !49, inlinedAt: !57) +!57 = !DILocation(line: 55, column: 26, scope: !9) +!58 = !DILocation(line: 291, column: 36, scope: !49, inlinedAt: !57) +!59 = !DILocation(line: 47, column: 20, scope: !9) +!60 = !DILocation(line: 574, column: 22, scope: !18, inlinedAt: !24) +!61 = !DILocation(line: 599, column: 19, scope: !18, inlinedAt: !24) +!62 = !DILocation(line: 600, column: 38, scope: !18, inlinedAt: !24) +!63 = !DILocation(line: 600, column: 46, scope: !18, inlinedAt: !24) +!64 = !DILocation(line: 600, column: 15, scope: !18, inlinedAt: !24) +!65 = !DILocation(line: 601, column: 59, scope: !18, inlinedAt: !24) +!66 = !DILocation(line: 601, column: 22, scope: !18, inlinedAt: !24) +!67 = !DILocation(line: 601, column: 48, scope: !18, inlinedAt: !24) +!68 = !DILocation(line: 291, column: 36, scope: !49, inlinedAt: !24) +!69 = !DILocation(line: 261, column: 15, scope: !49, inlinedAt: !24) +!70 = !DILocation(line: 591, column: 21, scope: !18, inlinedAt: !24) +!71 = !DILocation(line: 594, column: 40, scope: !18, inlinedAt: !24) +!72 = !DILocation(line: 594, column: 29, scope: !18, inlinedAt: !24) +!73 = !DILocation(line: 594, column: 23, scope: !18, inlinedAt: !24) +!74 = !DILocation(line: 58, column: 35, scope: !9) +!75 = !DILocation(line: 261, column: 15, scope: !49, inlinedAt: !76) +!76 = !DILocation(line: 59, column: 26, scope: !9) +!77 = !DILocation(line: 291, column: 36, scope: !49, inlinedAt: !76) +!78 = !DILocation(line: 61, column: 21, scope: !9) +!79 = !DILocation(line: 64, column: 19, scope: !9) +!80 = !DILocation(line: 66, column: 35, scope: !9) +!81 = !DILocation(line: 68, column: 20, scope: !9) +!82 = !DILocation(line: 69, column: 20, scope: !9) +!83 = !DILocation(line: 70, column: 35, scope: !9) +!84 = !DILocation(line: 71, column: 38, scope: !9) +!85 = !DILocation(line: 
71, column: 63, scope: !9) +!86 = !DILocation(line: 75, column: 19, scope: !9) +!87 = !DILocation(line: 76, column: 35, scope: !9) +!88 = !DILocation(line: 77, column: 20, scope: !9) +!89 = !DILocation(line: 78, column: 20, scope: !9) +!90 = !DILocation(line: 79, column: 35, scope: !9) +!91 = !DILocation(line: 80, column: 38, scope: !9) +!92 = !DILocation(line: 80, column: 63, scope: !9) +!93 = !DILocation(line: 81, column: 25, scope: !9) +!94 = !DILocation(line: 81, column: 37, scope: !9) +!95 = !DILocation(line: 82, column: 25, scope: !9) +!96 = !DILocation(line: 82, column: 37, scope: !9) +!97 = !DILocation(line: 83, column: 25, scope: !9) +!98 = !DILocation(line: 83, column: 47, scope: !9) +!99 = !DILocation(line: 84, column: 52, scope: !9) +!100 = !DILocation(line: 84, column: 49, scope: !9) +!101 = !DILocation(line: 84, column: 25, scope: !9) +!102 = !DILocation(line: 84, column: 85, scope: !9) +!103 = !DILocation(line: 85, column: 25, scope: !9) +!104 = !DILocation(line: 85, column: 47, scope: !9) +!105 = !DILocation(line: 86, column: 49, scope: !9) +!106 = !DILocation(line: 86, column: 25, scope: !9) +!107 = !DILocation(line: 86, column: 85, scope: !9) +!108 = !DILocation(line: 86, column: 4, scope: !9) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx new file mode 100644 index 0000000000000000000000000000000000000000..f94d8f7dcedef129d49ca49206d178e777f0a51a --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx @@ -0,0 +1,2103 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 // -- Begin function triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +) +.noreturn; +.global .align 1 .b8 assertFunc_1[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_1[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 115, 118, 47, 99, 115, 118, 52, 54, 122, 52, 110, 100, 102, 100, 54, 53, 101, 101, 98, 50, 119, 104, 119, 115, 117, 107, 106, 101, 104, 121, 100, 98, 99, 116, 54, 53, 122, 106, 112, 53, 103, 53, 113, 112, 104, 119, 120, 118, 118, 97, 101, 99, 51, 118, 121, 46, 112, 121}; +.global .align 1 .b8 assertMessage_1[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 52, 57, 32, 60, 32, 49, 55}; +.global .align 1 .b8 assertFunc_0[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_0[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 
109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 115, 118, 47, 99, 115, 118, 52, 54, 122, 52, 110, 100, 102, 100, 54, 53, 101, 101, 98, 50, 119, 104, 119, 115, 117, 107, 106, 101, 104, 121, 100, 98, 99, 116, 54, 53, 122, 106, 112, 53, 103, 53, 113, 112, 104, 119, 120, 118, 118, 97, 101, 99, 51, 118, 121, 46, 112, 121}; +.global .align 1 .b8 assertMessage_0[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 52, 48, 32, 60, 32, 49, 55}; +.extern .shared .align 16 .b8 global_smem[]; + // @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.visible .entry triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2( + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5, + .param .u64 .ptr .global .align 1 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6, + .param .u32 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_7, + .param .u32 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_8, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_9, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_10 +) +.reqntid 128 +{ + .reg .pred %p<324>; + .reg .b16 %rs<37>; + .reg .b32 %r<812>; + .reg .b64 %rd<82>; + .loc 1 18 0 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:18:0 +$L__func_begin0: + .loc 1 18 0 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:18:0 + +// %bb.0: + ld.param.b64 %rd17, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0]; +$L__tmp0: + .loc 1 24 28 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:24:28 + mov.u32 %r26, %ctaid.x; + .loc 1 24 33 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:24:33 + shl.b32 %r1, %r26, 5; + .loc 1 25 44 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:25:44 + mov.u32 %r2, %tid.x; + and.b32 %r27, %r2, 124; + bfe.u32 %r28, %r2, 2, 5; + and.b32 %r3, %r2, 31; + .loc 1 25 23 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:25:23 + or.b32 %r4, %r28, %r1; + .loc 1 26 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:26:21 + setp.gt.s32 %p3, %r4, 127; + setp.lt.s32 %p2, %r4, 128; + .loc 1 27 38 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:27:38 + and.b32 %r5, %r2, 3; + .loc 1 34 40 // 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:34:40 + shl.b32 %r29, %r4, 4; +$L__tmp1: + .loc 2 627 44 // triton_helpers.py:627:44 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.b32 %r30, %r2, 1; + shr.u32 %r31, %r2, 1; + bfe.u32 %r32, %r2, 1, 1; + .loc 2 537 21 // triton_helpers.py:537:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r33, %r30, 1; + xor.b32 %r34, %r32, 1; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ne.b32 %p4, %r30, 0; + and.b32 %r35, %r31, 1; + setp.ne.b32 %p5, %r35, 0; +$L__tmp2: + .loc 1 60 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:60:21 + shl.b32 %r36, %r27, 1; + mov.b32 %r37, global_smem; + add.s32 %r38, %r37, %r36; + shl.b32 %r39, %r3, 3; + add.s32 %r40, %r37, %r39; + .loc 1 27 38 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:27:38 + shl.b32 %r6, %r5, 2; + or.b32 %r7, %r6, 1; + or.b32 %r8, %r6, 2; + or.b32 %r9, %r6, 3; + .loc 1 34 37 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:34:37 + or.b32 %r41, %r29, %r6; + .loc 1 34 30 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:34:30 + mad.wide.s32 %rd13, %r41, 8, %rd17; + cvt.s64.s32 %rd18, %r29; + cvt.u64.u32 %rd19, %r6; + or.b64 %rd20, %rd18, %rd19; + shl.b64 %rd21, %rd20, 3; + add.s64 %rd22, %rd17, %rd21; + add.s64 %rd16, %rd22, 16; + .loc 1 34 45 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:34:45 + // begin inline asm + mov.u64 %rd11, 0x0; + mov.u64 %rd12, 0x0; + @%p2 ld.global.v2.b64 { %rd11, %rd12 }, [ %rd13 + 0 ]; + // end inline asm + // begin inline asm + mov.u64 %rd14, 0x0; + mov.u64 %rd15, 0x0; + @%p2 ld.global.v2.b64 { %rd14, %rd15 }, [ %rd16 + 0 ]; + // end inline asm +$L__tmp3: + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r42, %r7, %r6; + xor.b32 %r43, %r8, %r9; +$L__tmp4: + .loc 1 39 18 // 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:39:18 + add.s64 %rd23, %rd12, -1; + add.s64 %rd24, %rd14, -1; + add.s64 %rd25, %rd11, -1; + add.s64 %rd26, %rd15, -1; + setp.gt.u64 %p6, %rd26, 16382; + setp.gt.u64 %p7, %rd25, 16382; + setp.lt.u64 %p8, %rd26, 16383; + setp.lt.u64 %p9, %rd24, 16383; + setp.lt.u64 %p10, %rd23, 16383; + setp.lt.u64 %p11, %rd25, 16383; + .loc 1 0 0 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:0 + selp.b32 %r44, 1, 0, %p11; + selp.b32 %r45, 1, 0, %p10; + selp.b32 %r46, 1, 0, %p9; + selp.b32 %r47, 1, 0, %p8; +$L__tmp5: + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.pred %p12, %p10, %p7; + .loc 2 599 19 // triton_helpers.py:599:19 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + or.pred %p13, %p9, %p6; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r48, %r44, %r45; + xor.b32 %r49, %r46, %r47; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r50, %r48, 0, %p12; + selp.b32 %r51, %r49, 0, %p13; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r52, %r50, %r44; + xor.b32 %r53, %r50, %r45; + xor.b32 %r54, %r51, %r46; + xor.b32 %r55, %r51, %r47; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r56, %r42, 0, %p12; + selp.b32 %r57, %r43, 0, %p13; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r58, %r56, %r6; + xor.b32 %r59, %r56, %r7; + xor.b32 %r60, %r57, %r8; + xor.b32 %r61, %r57, %r9; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ge.u32 %p14, %r52, %r54; + setp.ge.u32 %p15, %r53, %r55; + setp.ne.b32 
%p16, %r53, %r55; + setp.ne.b32 %p17, %r54, %r52; + setp.le.u32 %p18, %r59, %r61; + setp.le.u32 %p19, %r58, %r60; + or.pred %p20, %p19, %p17; + or.pred %p21, %p18, %p16; + and.pred %p22, %p15, %p21; + and.pred %p23, %p14, %p20; + xor.pred %p24, %p23, %p4; + xor.pred %p25, %p22, %p4; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r62, %r52, %r54; + xor.b32 %r63, %r55, %r53; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r64, 0, %r63, %p25; + selp.b32 %r65, 0, %r62, %p24; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r66, %r65, %r52; + xor.b32 %r67, %r64, %r53; + xor.b32 %r68, %r64, %r55; + xor.b32 %r69, %r65, %r54; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r70, %r60, %r58; + xor.b32 %r71, %r61, %r59; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r72, 0, %r70, %p24; + selp.b32 %r73, 0, %r71, %p25; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r74, %r72, %r58; + xor.b32 %r75, %r73, %r59; + xor.b32 %r76, %r72, %r60; + xor.b32 %r77, %r73, %r61; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ge.u32 %p26, %r66, %r67; + setp.ne.b32 %p27, %r66, %r67; + setp.le.u32 %p28, %r74, %r75; + or.pred %p29, %p27, %p28; + and.pred %p30, %p26, %p29; + xor.pred %p31, %p30, %p4; + setp.ge.u32 %p32, %r69, %r68; + setp.ne.b32 %p33, %r69, %r68; + setp.le.u32 %p34, %r76, %r77; + or.pred %p35, %p33, %p34; + and.pred %p36, %p32, %p35; + xor.pred %p37, %p36, %p4; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + 
xor.b32 %r78, %r66, %r67; + xor.b32 %r79, %r69, %r68; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r80, 0, %r78, %p31; + selp.b32 %r81, 0, %r79, %p37; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r82, %r80, %r66; + xor.b32 %r83, %r80, %r67; + xor.b32 %r84, %r81, %r69; + xor.b32 %r85, %r81, %r68; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r86, %r74, %r75; + xor.b32 %r87, %r76, %r77; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r88, 0, %r86, %p31; + selp.b32 %r89, 0, %r87, %p37; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r90, %r88, %r74; + xor.b32 %r91, %r88, %r75; + xor.b32 %r92, %r89, %r76; + xor.b32 %r93, %r89, %r77; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r94, %r82, %r33; + mul.lo.s32 %r95, %r83, %r33; + mul.lo.s32 %r96, %r84, %r33; + mul.lo.s32 %r97, %r85, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r98, %r94, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r99, %r94, %r98; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r100, %r95, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r101, %r95, %r100; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r102, %r96, 1, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r103, %r96, %r102; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r104, %r97, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r105, %r97, %r104; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r106, %r82, %r30; + mul.lo.s32 %r107, %r83, %r30; + mul.lo.s32 %r108, %r84, %r30; + mul.lo.s32 %r109, %r85, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r110, %r106, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r111, %r106, %r110; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r112, %r107, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r113, %r107, %r112; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r114, %r108, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r115, %r108, %r114; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r116, %r109, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r117, %r109, %r116; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r118, %r90, %r33; + mul.lo.s32 %r119, %r91, %r33; + mul.lo.s32 %r120, %r92, %r33; + 
mul.lo.s32 %r121, %r93, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r122, %r118, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r123, %r118, %r122; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r124, %r119, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r125, %r119, %r124; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r126, %r120, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r127, %r120, %r126; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r128, %r121, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r129, %r121, %r128; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r130, %r90, %r30; + mul.lo.s32 %r131, %r91, %r30; + mul.lo.s32 %r132, %r92, %r30; + mul.lo.s32 %r133, %r93, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r134, %r130, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r135, %r130, %r134; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r136, %r131, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r137, %r131, %r136; + .loc 3 
291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r138, %r132, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r139, %r132, %r138; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r140, %r133, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r141, %r133, %r140; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ge.s32 %p38, %r99, %r111; + setp.ne.b32 %p39, %r99, %r111; + setp.le.s32 %p40, %r123, %r135; + or.pred %p41, %p39, %p40; + and.pred %p42, %p38, %p41; + xor.pred %p43, %p42, %p5; + setp.ge.s32 %p44, %r101, %r113; + setp.ne.b32 %p45, %r101, %r113; + setp.le.s32 %p46, %r125, %r137; + or.pred %p47, %p45, %p46; + and.pred %p48, %p44, %p47; + xor.pred %p49, %p48, %p5; + setp.ge.s32 %p50, %r103, %r115; + setp.ne.b32 %p51, %r103, %r115; + setp.le.s32 %p52, %r127, %r139; + or.pred %p53, %p51, %p52; + and.pred %p54, %p50, %p53; + xor.pred %p55, %p54, %p5; + setp.ge.s32 %p56, %r105, %r117; + setp.ne.b32 %p57, %r105, %r117; + setp.le.s32 %p58, %r129, %r141; + or.pred %p59, %p57, %p58; + and.pred %p60, %p56, %p59; + xor.pred %p61, %p60, %p5; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r142, %r99, %r111; + xor.b32 %r143, %r101, %r113; + xor.b32 %r144, %r103, %r115; + xor.b32 %r145, %r105, %r117; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r146, 0, %r142, %p43; + selp.b32 %r147, 0, %r143, %p49; + selp.b32 %r148, 0, %r144, %p55; + selp.b32 %r149, 0, %r145, %p61; + .loc 2 600 15 // triton_helpers.py:600:15 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r150, %r146, %r82; + xor.b32 %r151, %r147, %r83; + xor.b32 %r152, %r148, %r84; + xor.b32 %r153, %r149, %r85; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r154, %r123, %r135; + xor.b32 %r155, %r125, %r137; + xor.b32 %r156, %r127, %r139; + xor.b32 %r157, %r129, %r141; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r158, 0, %r154, %p43; + selp.b32 %r159, 0, %r155, %p49; + selp.b32 %r160, 0, %r156, %p55; + selp.b32 %r161, 0, %r157, %p61; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r162, %r158, %r90; + xor.b32 %r163, %r159, %r91; + xor.b32 %r164, %r160, %r92; + xor.b32 %r165, %r161, %r93; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ge.s32 %p62, %r150, %r152; + setp.ne.b32 %p63, %r150, %r152; + setp.le.s32 %p64, %r162, %r164; + or.pred %p65, %p63, %p64; + and.pred %p66, %p62, %p65; + xor.pred %p67, %p66, %p5; + setp.ge.s32 %p68, %r151, %r153; + setp.ne.b32 %p69, %r151, %r153; + setp.le.s32 %p70, %r163, %r165; + or.pred %p71, %p69, %p70; + and.pred %p72, %p68, %p71; + xor.pred %p73, %p72, %p5; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r166, %r150, %r152; + xor.b32 %r167, %r151, %r153; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r168, 0, %r166, %p67; + selp.b32 %r169, 0, %r167, %p73; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r170, %r168, %r150; + xor.b32 %r171, %r169, %r151; + xor.b32 %r172, %r168, %r152; + xor.b32 %r173, %r169, %r153; + .loc 2 601 48 // 
triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r174, %r162, %r164; + xor.b32 %r175, %r163, %r165; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r176, 0, %r174, %p67; + selp.b32 %r177, 0, %r175, %p73; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r178, %r176, %r162; + xor.b32 %r179, %r177, %r163; + xor.b32 %r180, %r176, %r164; + xor.b32 %r181, %r177, %r165; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.ge.s32 %p74, %r170, %r171; + setp.ne.b32 %p75, %r170, %r171; + setp.le.s32 %p76, %r178, %r179; + or.pred %p77, %p75, %p76; + and.pred %p78, %p74, %p77; + xor.pred %p79, %p78, %p5; + setp.ge.s32 %p80, %r172, %r173; + setp.ne.b32 %p81, %r172, %r173; + setp.le.s32 %p82, %r180, %r181; + or.pred %p83, %p81, %p82; + and.pred %p84, %p80, %p83; + xor.pred %p85, %p84, %p5; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r182, %r170, %r171; + xor.b32 %r183, %r172, %r173; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r184, 0, %r182, %p79; + selp.b32 %r185, 0, %r183, %p85; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r186, %r178, %r179; + xor.b32 %r187, %r180, %r181; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r188, 0, %r186, %p79; + selp.b32 %r189, 0, %r187, %p85; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r190, %r188, %r178; + xor.b32 %r191, %r188, %r179; + xor.b32 %r192, %r189, %r180; + xor.b32 %r193, %r189, %r181; + .loc 
2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r194, %r190, %r34; + mul.lo.s32 %r195, %r191, %r34; + mul.lo.s32 %r196, %r192, %r34; + mul.lo.s32 %r197, %r193, %r34; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r198, %r190, %r32; + mul.lo.s32 %r199, %r191, %r32; + mul.lo.s32 %r200, %r192, %r32; + mul.lo.s32 %r201, %r193, %r32; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r202, %r184, %r171; + xor.b32 %r203, %r184, %r170; + xor.b32 %r204, %r185, %r173; + xor.b32 %r205, %r185, %r172; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r206, %r203, %r34; + mul.lo.s32 %r207, %r202, %r34; + mul.lo.s32 %r208, %r205, %r34; + mul.lo.s32 %r209, %r204, %r34; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r210, %r206, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r211, %r206, %r210; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r212, %r207, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r213, %r207, %r212; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r214, %r208, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r215, %r208, %r214; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r216, %r209, 2, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r217, %r209, %r216; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r218, %r203, %r32; + mul.lo.s32 %r219, %r202, %r32; + mul.lo.s32 %r220, %r205, %r32; + mul.lo.s32 %r221, %r204, %r32; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r222, %r218, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r223, %r218, %r222; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r224, %r219, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r225, %r219, %r224; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r226, %r220, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r227, %r220, %r226; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r228, %r221, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r229, %r221, %r228; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r230, %r194, 2, 31, -1; + shfl.sync.bfly.b32 %r231, %r195, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r232, %r195, %r231; + add.s32 %r233, %r194, %r230; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 
%r234, %r196, 2, 31, -1; + shfl.sync.bfly.b32 %r235, %r197, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r236, %r197, %r235; + add.s32 %r237, %r196, %r234; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r238, %r198, 2, 31, -1; + shfl.sync.bfly.b32 %r239, %r199, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r240, %r199, %r239; + add.s32 %r241, %r198, %r238; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r242, %r200, 2, 31, -1; + shfl.sync.bfly.b32 %r243, %r201, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r244, %r201, %r243; + add.s32 %r245, %r200, %r242; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.lt.s32 %p86, %r213, %r225; + setp.lt.s32 %p87, %r211, %r223; + .loc 2 591 21 // triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.eq.b32 %p88, %r211, %r223; + setp.eq.b32 %p89, %r213, %r225; + .loc 2 594 40 // triton_helpers.py:594:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.gt.s32 %p90, %r233, %r241; + setp.gt.s32 %p91, %r232, %r240; + .loc 2 594 29 // triton_helpers.py:594:29 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.pred %p92, %p89, %p91; + and.pred %p93, %p88, %p90; + .loc 2 594 23 // triton_helpers.py:594:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + or.pred %p94, %p87, %p93; + or.pred %p95, %p86, %p92; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.lt.s32 %p96, %r217, %r229; + setp.lt.s32 
%p97, %r215, %r227; + .loc 2 591 21 // triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.eq.b32 %p98, %r215, %r227; + setp.eq.b32 %p99, %r217, %r229; + .loc 2 594 40 // triton_helpers.py:594:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.gt.s32 %p100, %r237, %r245; + setp.gt.s32 %p101, %r236, %r244; + .loc 2 594 29 // triton_helpers.py:594:29 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.pred %p102, %p99, %p101; + and.pred %p103, %p98, %p100; + .loc 2 594 23 // triton_helpers.py:594:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + or.pred %p104, %p97, %p103; + or.pred %p105, %p96, %p102; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r246, %r211, %r223; + xor.b32 %r247, %r213, %r225; + xor.b32 %r248, %r215, %r227; + xor.b32 %r249, %r217, %r229; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r250, %r247, 0, %p95; + selp.b32 %r251, %r246, 0, %p94; + selp.b32 %r252, %r249, 0, %p105; + selp.b32 %r253, %r248, 0, %p104; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r254, %r251, %r203; + xor.b32 %r255, %r250, %r202; + xor.b32 %r256, %r253, %r205; + xor.b32 %r257, %r252, %r204; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r258, %r233, %r241; + xor.b32 %r259, %r232, %r240; + xor.b32 %r260, %r237, %r245; + xor.b32 %r261, %r236, %r244; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r262, %r259, 0, %p95; + selp.b32 %r263, %r258, 0, %p94; + selp.b32 %r264, %r261, 0, %p105; + selp.b32 %r265, %r260, 0, %p104; + .loc 2 601 22 // triton_helpers.py:601:22 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r266, %r263, %r190; + xor.b32 %r267, %r262, %r191; + xor.b32 %r268, %r265, %r192; + xor.b32 %r269, %r264, %r193; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r270, %r255, %r33; + mul.lo.s32 %r271, %r254, %r33; + mul.lo.s32 %r272, %r257, %r33; + mul.lo.s32 %r273, %r256, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r274, %r271, 1, 31, -1; + shfl.sync.bfly.b32 %r275, %r270, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r276, %r271, %r274; + add.s32 %r277, %r270, %r275; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r278, %r273, 1, 31, -1; + shfl.sync.bfly.b32 %r279, %r272, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r280, %r273, %r278; + add.s32 %r281, %r272, %r279; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r282, %r255, %r30; + mul.lo.s32 %r283, %r254, %r30; + mul.lo.s32 %r284, %r257, %r30; + mul.lo.s32 %r285, %r256, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r286, %r283, 1, 31, -1; + shfl.sync.bfly.b32 %r287, %r282, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r288, %r283, %r286; + add.s32 %r289, %r282, %r287; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r290, %r285, 1, 31, -1; + shfl.sync.bfly.b32 %r291, %r284, 1, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r292, %r285, %r290; + add.s32 %r293, %r284, %r291; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r294, %r267, %r33; + mul.lo.s32 %r295, %r266, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r296, %r295, 1, 31, -1; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r297, %r269, %r33; + mul.lo.s32 %r298, %r268, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r299, %r294, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r300, %r294, %r299; + add.s32 %r301, %r295, %r296; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r302, %r269, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r303, %r298, 1, 31, -1; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r304, %r268, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r305, %r297, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r306, %r297, %r305; + add.s32 %r307, %r298, %r303; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + mul.lo.s32 %r308, %r267, %r30; + mul.lo.s32 %r309, %r266, %r30; + .loc 3 291 36 // standard.py:291:36 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r310, %r309, 1, 31, -1; + shfl.sync.bfly.b32 %r311, %r308, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r312, %r308, %r311; + add.s32 %r313, %r309, %r310; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r314, %r304, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r315, %r304, %r314; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + shfl.sync.bfly.b32 %r316, %r302, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + add.s32 %r317, %r316, %r302; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.lt.s32 %p106, %r277, %r289; + setp.lt.s32 %p107, %r276, %r288; + setp.lt.s32 %p108, %r281, %r293; + setp.lt.s32 %p109, %r280, %r292; + .loc 2 591 21 // triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.eq.b32 %p110, %r276, %r288; + setp.eq.b32 %p111, %r277, %r289; + setp.eq.b32 %p112, %r280, %r292; + setp.eq.b32 %p113, %r281, %r293; + .loc 2 594 40 // triton_helpers.py:594:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.gt.s32 %p114, %r301, %r313; + setp.gt.s32 %p115, %r300, %r312; + setp.gt.s32 %p116, %r307, %r315; + setp.gt.s32 %p117, %r306, %r317; + .loc 2 594 29 // triton_helpers.py:594:29 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.pred %p118, %p111, %p115; + and.pred %p119, %p110, %p114; + and.pred %p120, %p113, %p117; + and.pred %p121, %p112, %p116; + .loc 2 594 23 // triton_helpers.py:594:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] 
+ or.pred %p122, %p107, %p119; + or.pred %p123, %p106, %p118; + or.pred %p124, %p109, %p121; + or.pred %p125, %p108, %p120; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r318, %r276, %r288; + xor.b32 %r319, %r277, %r289; + xor.b32 %r320, %r280, %r292; + xor.b32 %r321, %r281, %r293; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r322, %r319, 0, %p123; + selp.b32 %r323, %r318, 0, %p122; + selp.b32 %r324, %r321, 0, %p125; + selp.b32 %r325, %r320, 0, %p124; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r326, %r323, %r254; + xor.b32 %r327, %r322, %r255; + xor.b32 %r328, %r325, %r256; + xor.b32 %r329, %r324, %r257; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r330, %r300, %r312; + xor.b32 %r331, %r301, %r313; + xor.b32 %r332, %r317, %r306; + xor.b32 %r333, %r307, %r315; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r334, %r331, 0, %p122; + selp.b32 %r335, %r330, 0, %p123; + selp.b32 %r336, %r333, 0, %p124; + selp.b32 %r337, %r332, 0, %p125; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r338, %r335, %r267; + xor.b32 %r339, %r334, %r266; + xor.b32 %r340, %r337, %r269; + xor.b32 %r341, %r336, %r268; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.lt.s32 %p126, %r327, %r329; + setp.lt.s32 %p127, %r326, %r328; + .loc 2 591 21 // triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.eq.b32 %p128, %r326, %r328; + setp.eq.b32 %p129, %r327, %r329; + .loc 2 594 40 // triton_helpers.py:594:40 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.gt.s32 %p130, %r339, %r341; + setp.gt.s32 %p131, %r338, %r340; + .loc 2 594 29 // triton_helpers.py:594:29 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + and.pred %p132, %p129, %p131; + and.pred %p133, %p128, %p130; + .loc 2 594 23 // triton_helpers.py:594:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + or.pred %p134, %p127, %p133; + or.pred %p135, %p126, %p132; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r342, %r326, %r328; + xor.b32 %r343, %r327, %r329; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r344, %r343, 0, %p135; + selp.b32 %r345, %r342, 0, %p134; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r346, %r345, %r328; + xor.b32 %r347, %r344, %r327; + xor.b32 %r348, %r345, %r326; + xor.b32 %r349, %r344, %r329; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r350, %r339, %r341; + xor.b32 %r351, %r338, %r340; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r352, %r351, 0, %p135; + selp.b32 %r353, %r350, 0, %p134; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r354, %r353, %r339; + xor.b32 %r355, %r352, %r338; + xor.b32 %r356, %r353, %r341; + xor.b32 %r357, %r352, %r340; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.lt.s32 %p136, %r348, %r347; + setp.lt.s32 %p137, %r346, %r349; + .loc 2 591 21 // triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.eq.b32 %p138, %r348, %r347; + 
setp.eq.b32 %p139, %r349, %r346; + .loc 2 594 40 // triton_helpers.py:594:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + setp.gt.s32 %p140, %r354, %r355; + setp.gt.s32 %p141, %r356, %r357; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r358, %r355, %r354; + xor.b32 %r359, %r357, %r356; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + selp.b32 %r360, %r359, 0, %p141; + selp.b32 %r361, %r360, 0, %p139; + selp.b32 %r362, %r359, %r361, %p137; + selp.b32 %r363, %r358, 0, %p140; + selp.b32 %r364, %r363, 0, %p138; + selp.b32 %r365, %r358, %r364, %p136; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:46:71 ] + xor.b32 %r12, %r365, %r355; + xor.b32 %r13, %r365, %r354; + xor.b32 %r11, %r362, %r356; + xor.b32 %r10, %r362, %r357; +$L__tmp6: + .loc 1 54 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:54:35 + and.pred %p142, %p2, %p9; + selp.b16 %rs1, 1, 0, %p142; + shl.b16 %rs2, %rs1, 2; + and.pred %p143, %p2, %p8; + selp.b16 %rs3, -1, 0, %p143; + shl.b16 %rs4, %rs3, 3; + or.b16 %rs5, %rs4, %rs2; + and.pred %p144, %p2, %p11; + selp.b16 %rs6, 1, 0, %p144; + and.pred %p145, %p2, %p10; + selp.b16 %rs7, -1, 0, %p145; + shl.b16 %rs8, %rs7, 1; + or.b16 %rs9, %rs6, %rs8; + and.b16 %rs10, %rs9, 3; + or.b16 %rs11, %rs10, %rs5; +$L__tmp7: + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:55:26 ] + and.b16 %rs12, %rs11, 15; + cvt.u32.u16 %r366, %rs12; + popc.b32 %r367, %r366; + cvt.u64.u32 %rd27, %r367; +$L__tmp8: + .loc 1 47 20 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:47:20 + setp.ne.b64 %p146, %rd15, 16384; + setp.ne.b64 %p147, %rd11, 16384; + setp.eq.b64 %p148, %rd15, 16384; + setp.eq.b64 %p149, %rd14, 16384; + setp.eq.b64 %p150, %rd12, 16384; + setp.eq.b64 %p151, %rd11, 16384; + .loc 1 0 0 // 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:0 + selp.b32 %r368, 1, 0, %p151; + selp.b32 %r369, 1, 0, %p150; + selp.b32 %r370, 1, 0, %p149; + selp.b32 %r371, 1, 0, %p148; +$L__tmp9: + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + and.pred %p152, %p150, %p147; + .loc 2 599 19 // triton_helpers.py:599:19 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + or.pred %p153, %p149, %p146; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r372, %r368, %r369; + xor.b32 %r373, %r370, %r371; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r374, %r372, 0, %p152; + selp.b32 %r375, %r373, 0, %p153; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r376, %r374, %r368; + xor.b32 %r377, %r374, %r369; + xor.b32 %r378, %r375, %r370; + xor.b32 %r379, %r375, %r371; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r380, %r42, 0, %p152; + selp.b32 %r381, %r43, 0, %p153; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r382, %r380, %r6; + xor.b32 %r383, %r380, %r7; + xor.b32 %r384, %r381, %r8; + xor.b32 %r385, %r381, %r9; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.ge.u32 %p154, %r376, %r378; + setp.ge.u32 %p155, %r377, %r379; + setp.ne.b32 %p156, %r377, %r379; + setp.ne.b32 %p157, %r378, %r376; + setp.le.u32 %p158, %r383, %r385; + setp.le.u32 %p159, %r382, %r384; + or.pred %p160, %p159, %p157; + or.pred %p161, %p158, %p156; + and.pred %p162, %p155, %p161; + and.pred %p163, %p154, %p160; + xor.pred %p164, %p163, %p4; + xor.pred %p165, %p162, %p4; + .loc 2 600 38 // 
triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r386, %r376, %r378; + xor.b32 %r387, %r379, %r377; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r388, 0, %r387, %p165; + selp.b32 %r389, 0, %r386, %p164; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r390, %r389, %r376; + xor.b32 %r391, %r388, %r377; + xor.b32 %r392, %r388, %r379; + xor.b32 %r393, %r389, %r378; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r394, %r384, %r382; + xor.b32 %r395, %r385, %r383; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r396, 0, %r394, %p164; + selp.b32 %r397, 0, %r395, %p165; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r398, %r396, %r382; + xor.b32 %r399, %r397, %r383; + xor.b32 %r400, %r396, %r384; + xor.b32 %r401, %r397, %r385; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.ge.u32 %p166, %r390, %r391; + setp.ne.b32 %p167, %r390, %r391; + setp.le.u32 %p168, %r398, %r399; + or.pred %p169, %p167, %p168; + and.pred %p170, %p166, %p169; + xor.pred %p171, %p170, %p4; + setp.ge.u32 %p172, %r393, %r392; + setp.ne.b32 %p173, %r393, %r392; + setp.le.u32 %p174, %r400, %r401; + or.pred %p175, %p173, %p174; + and.pred %p176, %p172, %p175; + xor.pred %p177, %p176, %p4; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r402, %r390, %r391; + xor.b32 %r403, %r393, %r392; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r404, 0, %r402, %p171; + selp.b32 
%r405, 0, %r403, %p177; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r406, %r404, %r390; + xor.b32 %r407, %r404, %r391; + xor.b32 %r408, %r405, %r393; + xor.b32 %r409, %r405, %r392; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r410, %r398, %r399; + xor.b32 %r411, %r400, %r401; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r412, 0, %r410, %p171; + selp.b32 %r413, 0, %r411, %p177; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r414, %r406, %r33; + mul.lo.s32 %r415, %r407, %r33; + mul.lo.s32 %r416, %r408, %r33; + mul.lo.s32 %r417, %r409, %r33; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r418, %r406, %r30; + mul.lo.s32 %r419, %r407, %r30; + mul.lo.s32 %r420, %r408, %r30; + mul.lo.s32 %r421, %r409, %r30; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r422, %r413, %r401; + xor.b32 %r423, %r413, %r400; + xor.b32 %r424, %r412, %r399; + xor.b32 %r425, %r412, %r398; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r426, %r425, %r33; + mul.lo.s32 %r427, %r424, %r33; + mul.lo.s32 %r428, %r423, %r33; + mul.lo.s32 %r429, %r422, %r33; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r430, %r425, %r30; + mul.lo.s32 %r431, %r424, %r30; + mul.lo.s32 %r432, %r423, %r30; + mul.lo.s32 %r433, %r422, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r434, %r414, 1, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r435, %r434, %r414; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r436, %r415, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r437, %r436, %r415; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r438, %r416, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r439, %r438, %r416; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r440, %r417, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r441, %r440, %r417; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r442, %r418, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r443, %r442, %r418; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r444, %r419, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r445, %r444, %r419; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r446, %r420, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r447, %r446, %r420; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r448, %r421, 1, 31, -1; + .loc 3 
261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r449, %r448, %r421; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.ge.s32 %p178, %r437, %r445; + setp.ne.b32 %p179, %r437, %r445; + setp.ge.s32 %p180, %r435, %r443; + setp.ge.s32 %p181, %r439, %r447; + setp.ne.b32 %p182, %r439, %r447; + setp.ne.b32 %p183, %r435, %r443; + setp.ge.s32 %p184, %r441, %r449; + setp.ne.b32 %p185, %r441, %r449; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r450, %r443, %r435; + xor.b32 %r451, %r445, %r437; + xor.b32 %r452, %r447, %r439; + xor.b32 %r453, %r449, %r441; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r454, %r426, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r455, %r454, %r426; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r456, %r427, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r457, %r456, %r427; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r458, %r428, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r459, %r458, %r428; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r460, %r429, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r461, %r460, %r429; + .loc 3 291 36 // standard.py:291:36 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r462, %r430, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r463, %r462, %r430; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r464, %r431, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r465, %r464, %r431; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r466, %r432, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r467, %r466, %r432; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r468, %r433, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r469, %r468, %r433; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.le.s32 %p186, %r459, %r467; + setp.le.s32 %p187, %r455, %r463; + or.pred %p188, %p183, %p187; + or.pred %p189, %p182, %p186; + and.pred %p190, %p181, %p189; + and.pred %p191, %p180, %p188; + xor.pred %p192, %p191, %p5; + xor.pred %p193, %p190, %p5; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r470, 0, %r452, %p193; + selp.b32 %r471, 0, %r450, %p192; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r472, %r471, %r406; + xor.b32 %r473, %r470, %r408; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.le.s32 %p194, %r461, %r469; + setp.le.s32 %p195, 
%r457, %r465; + or.pred %p196, %p179, %p195; + or.pred %p197, %p185, %p194; + and.pred %p198, %p184, %p197; + and.pred %p199, %p178, %p196; + xor.pred %p200, %p199, %p5; + xor.pred %p201, %p198, %p5; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r474, 0, %r453, %p201; + selp.b32 %r475, 0, %r451, %p200; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r476, %r475, %r407; + xor.b32 %r477, %r474, %r409; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r478, %r463, %r455; + xor.b32 %r479, %r465, %r457; + xor.b32 %r480, %r467, %r459; + xor.b32 %r481, %r469, %r461; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r482, 0, %r481, %p201; + selp.b32 %r483, 0, %r479, %p200; + selp.b32 %r484, 0, %r480, %p193; + selp.b32 %r485, 0, %r478, %p192; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r486, %r485, %r425; + xor.b32 %r487, %r484, %r423; + xor.b32 %r488, %r483, %r424; + xor.b32 %r489, %r482, %r422; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.ge.s32 %p202, %r472, %r473; + setp.ne.b32 %p203, %r472, %r473; + setp.le.s32 %p204, %r488, %r489; + setp.le.s32 %p205, %r486, %r487; + or.pred %p206, %p203, %p205; + and.pred %p207, %p202, %p206; + setp.ge.s32 %p208, %r476, %r477; + setp.ne.b32 %p209, %r476, %r477; + or.pred %p210, %p209, %p204; + and.pred %p211, %p208, %p210; + xor.pred %p212, %p211, %p5; + xor.pred %p213, %p207, %p5; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r490, %r473, %r472; + xor.b32 %r491, %r477, %r476; + .loc 2 600 46 // triton_helpers.py:600:46 
@[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r492, 0, %r490, %p213; + selp.b32 %r493, 0, %r491, %p212; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r494, %r492, %r472; + xor.b32 %r495, %r493, %r476; + xor.b32 %r496, %r492, %r473; + xor.b32 %r497, %r493, %r477; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r498, %r487, %r486; + xor.b32 %r499, %r489, %r488; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r500, 0, %r499, %p212; + selp.b32 %r501, 0, %r498, %p213; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r502, %r501, %r487; + xor.b32 %r503, %r500, %r488; + xor.b32 %r504, %r500, %r489; + xor.b32 %r505, %r501, %r486; + .loc 2 599 28 // triton_helpers.py:599:28 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.ge.s32 %p214, %r494, %r495; + setp.ne.b32 %p215, %r494, %r495; + setp.le.s32 %p216, %r505, %r503; + or.pred %p217, %p215, %p216; + and.pred %p218, %p214, %p217; + xor.pred %p219, %p218, %p5; + setp.ge.s32 %p220, %r496, %r497; + setp.ne.b32 %p221, %r496, %r497; + setp.le.s32 %p222, %r502, %r504; + or.pred %p223, %p221, %p222; + and.pred %p224, %p220, %p223; + xor.pred %p225, %p224, %p5; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r506, %r495, %r494; + xor.b32 %r507, %r497, %r496; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r508, 0, %r506, %p219; + selp.b32 %r509, 0, %r507, %p225; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r510, %r509, %r496; + xor.b32 %r511, %r508, %r494; + 
xor.b32 %r512, %r509, %r497; + xor.b32 %r513, %r508, %r495; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r514, %r503, %r505; + xor.b32 %r515, %r504, %r502; + .loc 2 601 59 // triton_helpers.py:601:59 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r516, 0, %r515, %p225; + selp.b32 %r517, 0, %r514, %p219; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r522, %r517, %r503; + xor.b32 %r523, %r516, %r502; + xor.b32 %r524, %r517, %r505; + xor.b32 %r525, %r516, %r504; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r526, %r511, %r34; + mul.lo.s32 %r527, %r513, %r34; + mul.lo.s32 %r528, %r510, %r34; + mul.lo.s32 %r529, %r512, %r34; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r530, %r526, 2, 31, -1; + shfl.sync.bfly.b32 %r531, %r527, 2, 31, -1; + shfl.sync.bfly.b32 %r532, %r528, 2, 31, -1; + shfl.sync.bfly.b32 %r533, %r529, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r534, %r527, %r531; + add.s32 %r535, %r529, %r533; + add.s32 %r536, %r526, %r530; + add.s32 %r537, %r528, %r532; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r538, %r511, %r32; + mul.lo.s32 %r539, %r513, %r32; + mul.lo.s32 %r540, %r510, %r32; + mul.lo.s32 %r541, %r512, %r32; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r542, %r538, 2, 31, -1; + shfl.sync.bfly.b32 %r543, %r539, 2, 31, -1; + shfl.sync.bfly.b32 %r544, %r540, 2, 31, -1; + shfl.sync.bfly.b32 %r545, %r541, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r546, %r539, %r543; + add.s32 %r547, %r541, %r545; + add.s32 %r548, %r538, %r542; + add.s32 %r549, %r540, %r544; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r550, %r525, %r34; + mul.lo.s32 %r551, %r523, %r34; + mul.lo.s32 %r552, %r522, %r34; + mul.lo.s32 %r553, %r524, %r34; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r554, %r553, 2, 31, -1; + shfl.sync.bfly.b32 %r555, %r552, 2, 31, -1; + shfl.sync.bfly.b32 %r556, %r551, 2, 31, -1; + shfl.sync.bfly.b32 %r557, %r550, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r558, %r551, %r556; + add.s32 %r559, %r553, %r554; + add.s32 %r560, %r550, %r557; + add.s32 %r561, %r552, %r555; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r562, %r525, %r32; + mul.lo.s32 %r563, %r523, %r32; + mul.lo.s32 %r564, %r522, %r32; + mul.lo.s32 %r565, %r524, %r32; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r566, %r565, 2, 31, -1; + shfl.sync.bfly.b32 %r567, %r564, 2, 31, -1; + shfl.sync.bfly.b32 %r568, %r563, 2, 31, -1; + shfl.sync.bfly.b32 %r569, %r562, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + add.s32 %r570, %r563, %r568; + add.s32 %r571, %r565, %r566; + add.s32 %r572, %r562, %r569; + add.s32 %r573, %r564, %r567; + .loc 2 574 22 // triton_helpers.py:574:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.lt.s32 %p226, %r537, %r549; + setp.lt.s32 %p227, %r536, %r548; + setp.lt.s32 %p228, %r535, %r547; + setp.lt.s32 %p229, %r534, %r546; + .loc 2 591 21 // 
triton_helpers.py:591:21 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.eq.b32 %p230, %r534, %r546; + setp.eq.b32 %p231, %r535, %r547; + setp.eq.b32 %p232, %r536, %r548; + setp.eq.b32 %p233, %r537, %r549; + .loc 2 594 40 // triton_helpers.py:594:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + setp.gt.s32 %p234, %r561, %r573; + setp.gt.s32 %p235, %r560, %r572; + setp.gt.s32 %p236, %r559, %r571; + setp.gt.s32 %p237, %r558, %r570; + .loc 2 594 29 // triton_helpers.py:594:29 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + and.pred %p238, %p233, %p237; + and.pred %p239, %p232, %p236; + and.pred %p240, %p231, %p235; + and.pred %p241, %p230, %p234; + .loc 2 594 23 // triton_helpers.py:594:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + or.pred %p242, %p229, %p241; + or.pred %p243, %p228, %p240; + or.pred %p244, %p227, %p239; + or.pred %p245, %p226, %p238; + .loc 2 600 38 // triton_helpers.py:600:38 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r574, %r534, %r546; + xor.b32 %r575, %r535, %r547; + xor.b32 %r576, %r536, %r548; + xor.b32 %r577, %r537, %r549; + .loc 2 600 46 // triton_helpers.py:600:46 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r578, %r577, 0, %p245; + selp.b32 %r579, %r576, 0, %p244; + selp.b32 %r580, %r575, 0, %p243; + selp.b32 %r581, %r574, 0, %p242; + .loc 2 600 15 // triton_helpers.py:600:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r582, %r581, %r513; + xor.b32 %r583, %r579, %r511; + xor.b32 %r584, %r578, %r510; + xor.b32 %r585, %r580, %r512; + .loc 2 601 48 // triton_helpers.py:601:48 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r586, %r561, %r573; + xor.b32 %r587, %r558, %r570; + xor.b32 %r588, %r560, %r572; + xor.b32 %r589, %r559, %r571; + .loc 2 601 59 // triton_helpers.py:601:59 @[ 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + selp.b32 %r590, %r589, 0, %p244; + selp.b32 %r591, %r588, 0, %p243; + selp.b32 %r592, %r587, 0, %p245; + selp.b32 %r593, %r586, 0, %p242; + .loc 2 601 22 // triton_helpers.py:601:22 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + xor.b32 %r598, %r593, %r522; + xor.b32 %r599, %r590, %r524; + xor.b32 %r600, %r591, %r525; + xor.b32 %r601, %r592, %r523; + .loc 2 538 40 // triton_helpers.py:538:40 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r602, %r585, %r33; + mul.lo.s32 %r603, %r584, %r33; + mul.lo.s32 %r604, %r582, %r33; + mul.lo.s32 %r605, %r583, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r606, %r605, 1, 31, -1; + shfl.sync.bfly.b32 %r607, %r604, 1, 31, -1; + shfl.sync.bfly.b32 %r608, %r603, 1, 31, -1; + shfl.sync.bfly.b32 %r609, %r602, 1, 31, -1; + .loc 2 539 41 // triton_helpers.py:539:41 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r614, %r585, %r30; + mul.lo.s32 %r615, %r584, %r30; + mul.lo.s32 %r616, %r582, %r30; + mul.lo.s32 %r617, %r583, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r618, %r617, 1, 31, -1; + shfl.sync.bfly.b32 %r619, %r616, 1, 31, -1; + shfl.sync.bfly.b32 %r620, %r615, 1, 31, -1; + shfl.sync.bfly.b32 %r621, %r614, 1, 31, -1; + .loc 2 548 23 // triton_helpers.py:548:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r626, %r599, %r33; + mul.lo.s32 %r627, %r598, %r33; + mul.lo.s32 %r628, %r601, %r33; + mul.lo.s32 %r629, %r600, %r33; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r630, %r626, 1, 31, -1; + shfl.sync.bfly.b32 %r631, %r627, 1, 31, -1; + shfl.sync.bfly.b32 %r632, %r628, 1, 31, -1; + 
shfl.sync.bfly.b32 %r633, %r629, 1, 31, -1; + .loc 2 551 23 // triton_helpers.py:551:23 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + mul.lo.s32 %r638, %r599, %r30; + mul.lo.s32 %r639, %r598, %r30; + mul.lo.s32 %r640, %r601, %r30; + mul.lo.s32 %r641, %r600, %r30; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:51:71 ] + shfl.sync.bfly.b32 %r642, %r638, 1, 31, -1; + shfl.sync.bfly.b32 %r643, %r639, 1, 31, -1; + shfl.sync.bfly.b32 %r644, %r640, 1, 31, -1; + shfl.sync.bfly.b32 %r645, %r641, 1, 31, -1; +$L__tmp10: + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:55:26 ] + shfl.sync.bfly.b32 %r720, %r367, 2, 31, -1; + mov.b32 %r721, 0; + shfl.sync.bfly.b32 %r722, %r721, 2, 31, -1; + cvt.u64.u32 %rd28, %r720; + cvt.u64.u32 %rd29, %r722; + shl.b64 %rd30, %rd29, 32; + or.b64 %rd31, %rd28, %rd30; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:55:26 ] + add.s64 %rd32, %rd27, %rd31; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:55:26 ] + mov.b64 {_, %r723}, %rd32; + cvt.u32.u64 %r724, %rd32; + shfl.sync.bfly.b32 %r725, %r724, 1, 31, -1; + shfl.sync.bfly.b32 %r726, %r723, 1, 31, -1; + cvt.u64.u32 %rd33, %r725; + cvt.u64.u32 %rd34, %r726; + shl.b64 %rd35, %rd34, 32; + or.b64 %rd36, %rd33, %rd35; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:55:26 ] + add.s64 %rd37, %rd32, %rd36; +$L__tmp11: + .loc 1 60 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:60:21 + st.shared.b64 [%r38], %rd37; + bar.sync 0; + ld.shared.b64 %rd2, [%r40]; + .loc 1 58 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:58:35 + and.pred %p282, %p2, %p149; + selp.b16 %rs13, 1, 0, %p282; + shl.b16 %rs14, %rs13, 2; + and.pred %p283, %p2, %p148; + selp.b16 %rs15, -1, 0, %p283; + shl.b16 %rs16, %rs15, 3; + or.b16 %rs17, 
%rs16, %rs14; + and.pred %p284, %p2, %p151; + selp.b16 %rs18, 1, 0, %p284; + and.pred %p285, %p2, %p150; + selp.b16 %rs19, -1, 0, %p285; + shl.b16 %rs20, %rs19, 1; + or.b16 %rs21, %rs18, %rs20; + and.b16 %rs22, %rs21, 3; + or.b16 %rs23, %rs22, %rs17; +$L__tmp12: + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:59:26 ] + and.b16 %rs24, %rs23, 15; + cvt.u32.u16 %r727, %rs24; + popc.b32 %r728, %r727; + cvt.u64.u32 %rd38, %r728; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:59:26 ] + shfl.sync.bfly.b32 %r729, %r728, 2, 31, -1; + shfl.sync.bfly.b32 %r730, %r721, 2, 31, -1; + cvt.u64.u32 %rd39, %r729; + cvt.u64.u32 %rd40, %r730; + shl.b64 %rd41, %rd40, 32; + or.b64 %rd42, %rd39, %rd41; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:59:26 ] + add.s64 %rd43, %rd38, %rd42; + .loc 3 291 36 // standard.py:291:36 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:59:26 ] + mov.b64 {_, %r731}, %rd43; + cvt.u32.u64 %r732, %rd43; + shfl.sync.bfly.b32 %r733, %r732, 1, 31, -1; + shfl.sync.bfly.b32 %r734, %r731, 1, 31, -1; + cvt.u64.u32 %rd44, %r733; + cvt.u64.u32 %rd45, %r734; + shl.b64 %rd46, %rd45, 32; + or.b64 %rd47, %rd44, %rd46; + .loc 3 261 15 // standard.py:261:15 @[ csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:59:26 ] + add.s64 %rd3, %rd43, %rd47; +$L__tmp13: + .loc 1 61 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:61:21 + bar.sync 0; + st.shared.b64 [%r38], %rd3; + bar.sync 0; + .loc 1 60 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:60:21 + cvt.u32.u64 %r735, %rd37; + .loc 1 64 19 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:64:19 + setp.lt.s32 %p286, %r7, %r735; + setp.lt.s32 %p287, %r6, %r735; + setp.lt.s32 %p288, %r8, %r735; + setp.lt.s32 %p289, %r9, %r735; + .loc 1 66 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:66:35 + selp.b32 %r736, 
%r10, 16, %p289; + selp.b32 %r737, %r11, 16, %p288; + selp.b32 %r738, %r13, 16, %p287; + selp.b32 %r739, %r12, 16, %p286; + .loc 1 68 20 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:68:20 + add.s32 %r740, %r739, 17; + add.s32 %r741, %r738, 17; + add.s32 %r742, %r737, 17; + add.s32 %r743, %r736, 17; + .loc 1 69 20 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:69:20 + setp.lt.s32 %p290, %r739, 0; + setp.lt.s32 %p291, %r738, 0; + setp.lt.s32 %p292, %r737, 0; + setp.lt.s32 %p293, %r736, 0; + .loc 1 70 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:70:35 + selp.b32 %r18, %r743, %r736, %p293; + selp.b32 %r19, %r742, %r737, %p292; + selp.b32 %r21, %r741, %r738, %p291; + selp.b32 %r20, %r740, %r739, %p290; + .loc 1 71 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:71:63 + max.u32 %r744, %r21, %r20; + max.u32 %r745, %r19, %r744; + max.u32 %r746, %r18, %r745; + setp.lt.u32 %p294, %r746, 17; + or.pred %p295, %p3, %p294; + @%p295 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + .loc 1 0 0 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:0 + xor.b32 %r518, %r501, %r517; + xor.b32 %r519, %r500, %r516; + xor.b32 %r520, %r501, %r516; + xor.b32 %r521, %r500, %r517; + xor.b32 %r594, %r521, %r593; + xor.b32 %r595, %r520, %r592; + xor.b32 %r596, %r519, %r591; + xor.b32 %r597, %r518, %r590; + add.s32 %r610, %r604, %r607; + add.s32 %r611, %r602, %r609; + add.s32 %r612, %r605, %r606; + add.s32 %r613, %r603, %r608; + add.s32 %r622, %r616, %r619; + add.s32 %r623, %r614, %r621; + add.s32 %r624, %r617, %r618; + add.s32 %r625, %r615, %r620; + add.s32 %r634, %r628, %r632; + add.s32 %r635, %r627, %r631; + add.s32 %r636, %r626, %r630; + add.s32 %r637, %r629, %r633; + add.s32 %r646, %r644, %r640; + add.s32 %r647, %r643, %r639; + add.s32 %r648, %r642, %r638; + add.s32 %r649, %r645, %r641; + setp.lt.s32 %p246, %r613, %r625; + setp.lt.s32 %p247, %r612, %r624; + setp.lt.s32 %p248, %r611, %r623; + setp.lt.s32 %p249, %r610, %r622; + 
setp.eq.b32 %p250, %r610, %r622; + setp.eq.b32 %p251, %r611, %r623; + setp.eq.b32 %p252, %r612, %r624; + setp.eq.b32 %p253, %r613, %r625; + setp.gt.s32 %p254, %r635, %r647; + setp.gt.s32 %p255, %r637, %r649; + setp.gt.s32 %p256, %r636, %r648; + setp.gt.s32 %p257, %r634, %r646; + and.pred %p258, %p253, %p257; + and.pred %p259, %p252, %p256; + and.pred %p260, %p251, %p255; + and.pred %p261, %p250, %p254; + or.pred %p262, %p249, %p261; + or.pred %p263, %p248, %p260; + or.pred %p264, %p247, %p259; + or.pred %p265, %p246, %p258; + xor.b32 %r650, %r610, %r622; + xor.b32 %r651, %r611, %r623; + xor.b32 %r652, %r612, %r624; + xor.b32 %r653, %r613, %r625; + selp.b32 %r654, %r653, 0, %p265; + selp.b32 %r655, %r652, 0, %p264; + selp.b32 %r656, %r651, 0, %p263; + selp.b32 %r657, %r650, 0, %p262; + xor.b32 %r658, %r654, %r584; + xor.b32 %r659, %r655, %r583; + xor.b32 %r660, %r657, %r582; + xor.b32 %r661, %r656, %r585; + xor.b32 %r662, %r646, %r634; + xor.b32 %r663, %r647, %r635; + xor.b32 %r664, %r648, %r636; + xor.b32 %r665, %r649, %r637; + selp.b32 %r666, %r663, 0, %p262; + selp.b32 %r667, %r662, 0, %p265; + selp.b32 %r668, %r665, 0, %p263; + selp.b32 %r669, %r664, 0, %p264; + xor.b32 %r670, %r597, %r669; + xor.b32 %r671, %r596, %r668; + xor.b32 %r672, %r669, %r599; + xor.b32 %r673, %r667, %r601; + xor.b32 %r674, %r666, %r598; + xor.b32 %r675, %r668, %r600; + setp.lt.s32 %p266, %r659, %r658; + setp.lt.s32 %p267, %r660, %r661; + setp.eq.b32 %p268, %r659, %r658; + setp.eq.b32 %p269, %r660, %r661; + setp.gt.s32 %p270, %r672, %r673; + setp.gt.s32 %p271, %r674, %r675; + and.pred %p272, %p268, %p270; + and.pred %p273, %p269, %p271; + or.pred %p274, %p266, %p272; + or.pred %p275, %p267, %p273; + xor.b32 %r676, %r658, %r659; + xor.b32 %r677, %r661, %r660; + selp.b32 %r678, %r677, 0, %p275; + selp.b32 %r679, %r676, 0, %p274; + xor.b32 %r680, %r679, %r658; + xor.b32 %r681, %r678, %r660; + xor.b32 %r682, %r679, %r659; + xor.b32 %r683, %r678, %r661; + xor.b32 %r684, %r675, %r674; + 
xor.b32 %r685, %r673, %r672; + selp.b32 %r686, %r685, 0, %p274; + selp.b32 %r687, %r684, 0, %p275; + xor.b32 %r688, %r671, %r687; + xor.b32 %r689, %r667, %r686; + xor.b32 %r690, %r689, %r595; + xor.b32 %r691, %r666, %r687; + xor.b32 %r692, %r691, %r594; + xor.b32 %r693, %r670, %r686; + xor.b32 %r694, %r687, %r675; + xor.b32 %r695, %r687, %r674; + xor.b32 %r696, %r686, %r673; + xor.b32 %r697, %r686, %r672; + setp.lt.s32 %p276, %r682, %r681; + setp.lt.s32 %p277, %r680, %r683; + setp.eq.b32 %p278, %r681, %r682; + setp.eq.b32 %p279, %r680, %r683; + setp.gt.s32 %p280, %r697, %r695; + setp.gt.s32 %p281, %r696, %r694; + xor.b32 %r698, %r689, %r668; + xor.b32 %r699, %r691, %r669; + xor.b32 %r700, %r699, %r522; + xor.b32 %r701, %r698, %r523; + xor.b32 %r702, %r701, %r592; + xor.b32 %r703, %r700, %r593; + xor.b32 %r704, %r703, %r524; + xor.b32 %r705, %r702, %r525; + xor.b32 %r706, %r705, %r591; + xor.b32 %r707, %r704, %r590; + xor.b32 %r708, %r707, %r686; + xor.b32 %r709, %r706, %r687; + selp.b32 %r710, %r709, 0, %p281; + selp.b32 %r711, %r710, 0, %p279; + selp.b32 %r712, %r709, %r711, %p277; + selp.b32 %r713, %r708, 0, %p280; + selp.b32 %r714, %r713, 0, %p278; + selp.b32 %r715, %r708, %r714, %p276; + xor.b32 %r716, %r693, %r715; + xor.b32 %r717, %r692, %r715; + xor.b32 %r718, %r690, %r712; + xor.b32 %r719, %r688, %r712; + xor.b32 %r17, %r719, %r489; + xor.b32 %r16, %r718, %r487; + xor.b32 %r15, %r717, %r488; + xor.b32 %r14, %r716, %r486; + .loc 1 61 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:61:21 + ld.shared.b64 %rd4, [%r40]; + cvt.u32.u64 %r747, %rd3; + .loc 1 71 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:71:63 + bar.sync 0; + .loc 1 75 19 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:75:19 + setp.lt.s32 %p297, %r8, %r747; + setp.lt.s32 %p298, %r9, %r747; + setp.lt.s32 %p299, %r6, %r747; + setp.lt.s32 %p300, %r7, %r747; + .loc 1 76 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:76:35 + selp.b32 %r748, %r15, 
16, %p300; + selp.b32 %r749, %r14, 16, %p299; + selp.b32 %r750, %r17, 16, %p298; + selp.b32 %r751, %r16, 16, %p297; + .loc 1 77 20 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:77:20 + add.s32 %r752, %r751, 17; + add.s32 %r753, %r750, 17; + add.s32 %r754, %r749, 17; + add.s32 %r755, %r748, 17; + .loc 1 78 20 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:78:20 + setp.lt.s32 %p301, %r751, 0; + setp.lt.s32 %p302, %r750, 0; + setp.lt.s32 %p303, %r749, 0; + setp.lt.s32 %p304, %r748, 0; + .loc 1 79 35 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:79:35 + selp.b32 %r23, %r755, %r748, %p304; + selp.b32 %r22, %r754, %r749, %p303; + selp.b32 %r25, %r753, %r750, %p302; + selp.b32 %r24, %r752, %r751, %p301; + .loc 1 80 38 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:80:38 + setp.gt.u32 %p305, %r24, 16; + selp.b16 %rs25, 1, 0, %p305; + shl.b16 %rs26, %rs25, 2; + setp.gt.u32 %p306, %r25, 16; + selp.b16 %rs27, -1, 0, %p306; + shl.b16 %rs28, %rs27, 3; + or.b16 %rs29, %rs28, %rs26; + setp.gt.u32 %p307, %r22, 16; + selp.b16 %rs30, 1, 0, %p307; + setp.gt.u32 %p308, %r23, 16; + selp.b16 %rs31, -1, 0, %p308; + shl.b16 %rs32, %rs31, 1; + or.b16 %rs33, %rs30, %rs32; + and.b16 %rs34, %rs33, 3; + or.b16 %rs35, %rs34, %rs29; + .loc 1 80 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:80:63 + and.b16 %rs36, %rs35, 15; + setp.eq.b16 %p309, %rs36, 0; + or.pred %p310, %p3, %p309; + @%p310 bra $L__BB0_4; + bra.uni $L__BB0_3; +$L__BB0_4: + .loc 1 0 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:0:63 + ld.param.b64 %rd10, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6]; + ld.param.b64 %rd9, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5]; + ld.param.b64 %rd8, 
[triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4]; + ld.param.b64 %rd7, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3]; + ld.param.b64 %rd6, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2]; + ld.param.b64 %rd5, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1]; + cvt.s64.s32 %rd1, %r41; + .loc 1 61 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:61:21 + cvt.u32.u64 %r757, %rd4; + .loc 1 60 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:60:21 + cvt.u32.u64 %r756, %rd2; + .loc 1 25 23 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:25:23 + or.b32 %r774, %r1, %r3; + .loc 1 26 21 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:26:21 + setp.lt.s32 %p314, %r774, 128; + .loc 1 80 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:80:63 + bar.sync 0; + .loc 1 81 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:81:25 + mul.wide.s32 %rd60, %r774, 4; + add.s64 %rd48, %rd5, %rd60; + .loc 1 81 37 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:81:37 + and.b32 %r775, %r2, 96; + setp.eq.b32 %p323, %r775, 0; + and.pred %p311, %p323, %p314; + // begin inline asm + @%p311 st.global.b32 [ %rd48 + 0 ], { %r756 }; + // end inline asm + .loc 1 82 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:82:25 + add.s64 %rd49, %rd6, %rd60; + .loc 1 82 37 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:82:37 + // begin inline asm + @%p311 st.global.b32 [ %rd49 + 0 ], { %r757 }; + // end inline asm + .loc 1 83 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:83:25 + shl.b64 %rd61, %rd1, 2; + add.s64 %rd50, %rd7, %rd61; + .loc 1 83 47 // 
csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:83:47 + // begin inline asm + @%p2 st.global.v4.b32 [ %rd50 + 0 ], { %r13, %r12, %r11, %r10 }; + // end inline asm + .loc 1 84 52 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:84:52 + mul.lo.s32 %r776, %r4, 17; + .loc 1 84 49 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:84:49 + add.s32 %r777, %r21, %r776; + add.s32 %r778, %r20, %r776; + add.s32 %r779, %r19, %r776; + add.s32 %r780, %r18, %r776; + .loc 1 84 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:84:25 + mad.wide.s32 %rd62, %r777, 4, %rd8; + mad.wide.s32 %rd63, %r778, 4, %rd8; + mad.wide.s32 %rd64, %r779, 4, %rd8; + mad.wide.s32 %rd65, %r780, 4, %rd8; + .loc 1 84 85 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:84:85 + bar.sync 0; + and.b32 %r781, %r2, 48; + shl.b32 %r782, %r781, 6; + shl.b32 %r783, %r2, 3; + and.b32 %r784, %r783, 120; + shr.u32 %r785, %r781, 1; + shl.b32 %r786, %r2, 1; + and.b32 %r787, %r786, 128; + or.b32 %r788, %r782, %r784; + xor.b32 %r789, %r788, %r785; + add.s32 %r791, %r37, %r787; + add.s32 %r792, %r791, %r789; + st.shared.b64 [%r792], %rd62; + st.shared.b64 [%r792+256], %rd63; + st.shared.b64 [%r792+512], %rd64; + st.shared.b64 [%r792+768], %rd65; + bar.sync 0; + and.b32 %r793, %r2, 12; + shl.b32 %r794, %r793, 8; + shl.b32 %r795, %r5, 5; + and.b32 %r796, %r783, 896; + shl.b32 %r797, %r793, 1; + or.b32 %r798, %r794, %r795; + or.b32 %r799, %r798, %r796; + or.b32 %r800, %r799, %r797; + add.s32 %r801, %r37, %r800; + ld.shared.b64 %rd51, [%r801]; + xor.b32 %r802, %r800, 8; + add.s32 %r803, %r37, %r802; + ld.shared.b64 %rd52, [%r803]; + xor.b32 %r804, %r800, 16; + add.s32 %r805, %r37, %r804; + ld.shared.b64 %rd53, [%r805]; + xor.b32 %r806, %r800, 24; + add.s32 %r807, %r37, %r806; + ld.shared.b64 %rd54, [%r807]; + mov.b32 %r762, 1; + // begin inline asm + @%p314 st.global.b32 [ %rd51 + 0 ], { %r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd52 + 0 ], { 
%r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd53 + 0 ], { %r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd54 + 0 ], { %r762 }; + // end inline asm + .loc 1 85 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:85:25 + add.s64 %rd55, %rd9, %rd61; + .loc 1 85 47 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:85:47 + // begin inline asm + @%p2 st.global.v4.b32 [ %rd55 + 0 ], { %r14, %r15, %r16, %r17 }; + // end inline asm + .loc 1 86 49 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:86:49 + add.s32 %r808, %r22, %r776; + add.s32 %r809, %r23, %r776; + add.s32 %r810, %r24, %r776; + add.s32 %r811, %r25, %r776; + .loc 1 86 25 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:86:25 + mad.wide.s32 %rd66, %r808, 4, %rd10; + mad.wide.s32 %rd67, %r809, 4, %rd10; + mad.wide.s32 %rd68, %r810, 4, %rd10; + mad.wide.s32 %rd69, %r811, 4, %rd10; + .loc 1 86 85 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:86:85 + bar.sync 0; + st.shared.b64 [%r792], %rd66; + st.shared.b64 [%r792+256], %rd67; + st.shared.b64 [%r792+512], %rd68; + st.shared.b64 [%r792+768], %rd69; + bar.sync 0; + ld.shared.b64 %rd56, [%r801]; + ld.shared.b64 %rd57, [%r803]; + ld.shared.b64 %rd58, [%r805]; + ld.shared.b64 %rd59, [%r807]; + // begin inline asm + @%p314 st.global.b32 [ %rd56 + 0 ], { %r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd57 + 0 ], { %r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd58 + 0 ], { %r762 }; + // end inline asm + // begin inline asm + @%p314 st.global.b32 [ %rd59 + 0 ], { %r762 }; + // end inline asm + .loc 1 86 4 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:86:4 + ret; +$L__BB0_1: + .loc 1 71 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:71:63 + { // callseq 1, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + 
mov.b64 %rd76, assertFunc_0; + cvta.global.u64 %rd77, %rd76; + st.param.b64 [param3], %rd77; + mov.b64 %rd78, assertFile_0; + cvta.global.u64 %rd79, %rd78; + st.param.b64 [param1], %rd79; + mov.b64 %rd80, assertMessage_0; + cvta.global.u64 %rd81, %rd80; + st.param.b64 [param0], %rd81; + st.param.b64 [param4], 1; + st.param.b32 [param2], 71; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 1 + trap; +$L__BB0_3: + .loc 1 80 63 // csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py:80:63 + { // callseq 0, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd70, assertFunc_1; + cvta.global.u64 %rd71, %rd70; + st.param.b64 [param3], %rd71; + mov.b64 %rd72, assertFile_1; + cvta.global.u64 %rd73, %rd72; + st.param.b64 [param1], %rd73; + mov.b64 %rd74, assertMessage_1; + cvta.global.u64 %rd75, %rd74; + st.param.b64 [param0], %rd75; + st.param.b64 [param4], 1; + st.param.b32 [param2], 80; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 0 + trap; +$L__tmp14: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .file 3 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 
// DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 376 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x171 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 115 +.b8 118 +.b8 52 +.b8 54 +.b8 122 +.b8 52 +.b8 110 +.b8 100 +.b8 102 +.b8 100 +.b8 54 +.b8 53 +.b8 101 +.b8 101 +.b8 98 +.b8 50 +.b8 119 +.b8 104 +.b8 119 +.b8 115 +.b8 117 +.b8 107 +.b8 106 +.b8 101 +.b8 104 +.b8 121 +.b8 100 +.b8 98 +.b8 99 +.b8 116 +.b8 54 +.b8 53 +.b8 122 +.b8 106 +.b8 112 +.b8 53 +.b8 103 +.b8 53 +.b8 113 +.b8 112 +.b8 104 +.b8 119 +.b8 120 +.b8 118 +.b8 118 +.b8 97 +.b8 101 +.b8 99 +.b8 51 +.b8 118 +.b8 121 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 
104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 115 +.b8 118 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x7a DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 112 +.b8 101 +.b8 114 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 95 +.b8 116 +.b8 111 +.b8 95 +.b8 99 +.b8 111 +.b8 112 +.b8 121 +.b8 95 +.b8 97 +.b8 114 +.b8 97 +.b8 110 +.b8 103 +.b8 101 +.b8 95 +.b8 98 +.b8 105 +.b8 116 +.b8 119 +.b8 105 +.b8 115 +.b8 101 +.b8 95 +.b8 97 +.b8 110 +.b8 100 +.b8 95 +.b8 101 +.b8 113 +.b8 95 +.b8 103 +.b8 116 +.b8 95 +.b8 105 +.b8 110 +.b8 100 +.b8 101 +.b8 120 +.b8 95 +.b8 112 +.b8 117 +.b8 116 +.b8 95 +.b8 108 +.b8 116 +.b8 95 +.b8 110 +.b8 101 +.b8 119 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 115 +.b8 99 +.b8 97 +.b8 108 +.b8 97 +.b8 114 +.b8 95 +.b8 116 +.b8 101 +.b8 110 +.b8 115 +.b8 111 +.b8 114 +.b8 95 +.b8 115 +.b8 111 +.b8 114 +.b8 116 +.b8 95 +.b8 115 +.b8 117 +.b8 109 +.b8 95 +.b8 117 +.b8 110 +.b8 115 +.b8 113 +.b8 117 +.b8 101 +.b8 101 +.b8 122 +.b8 101 +.b8 95 +.b8 118 +.b8 105 +.b8 101 +.b8 119 +.b8 95 +.b8 119 +.b8 104 +.b8 101 +.b8 114 +.b8 101 +.b8 95 +.b8 50 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0x105:0x76 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0x11a:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp6 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 46 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x132:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp7 // DW_AT_low_pc +.b64 $L__tmp11 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 55 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x14a:0x18 
DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp9 // DW_AT_low_pc +.b64 $L__tmp10 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 51 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x162:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp12 // DW_AT_low_pc +.b64 $L__tmp13 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 59 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source new file mode 100644 index 0000000000000000000000000000000000000000..3f3118f454602e63e34245825f7539b22acb63d5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source @@ -0,0 +1,1405 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":18:0) +#loc91 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":640:0) +#loc95 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":607:0) +#loc103 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":518:0) +#loc141 = loc(unknown) +#loc166 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc170 
= loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc175 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":86:0) +#loc179 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":63:0) +#loc188 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":131:0) +#loc193 = loc("in_ptr0"(#loc)) +#loc194 = loc("out_ptr4"(#loc)) +#loc195 = loc("out_ptr5"(#loc)) +#loc196 = loc("out_ptr6"(#loc)) +#loc197 = loc("out_ptr7"(#loc)) +#loc198 = loc("out_ptr8"(#loc)) +#loc199 = loc("out_ptr9"(#loc)) +#loc200 = loc("xnumel"(#loc)) +#loc201 = loc("r0_numel"(#loc)) +#loc257 = loc("x"(#loc91)) +#loc258 = loc("idxs"(#loc91)) +#loc259 = loc("x"(#loc95)) +#loc260 = loc("idxs"(#loc95)) +#loc265 = loc("x"(#loc103)) +#loc266 = loc("idxs"(#loc103)) +#loc267 = loc("flip"(#loc103)) +#loc323 = loc("input"(#loc166)) +#loc324 = loc("a"(#loc170)) +#loc325 = loc("b"(#loc170)) +#loc327 = loc("x"(#loc175)) +#loc328 = loc("x"(#loc179)) +#loc329 = loc("input"(#loc188)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 128 : i32 loc(#loc202) + %r0_numel_1 = arith.constant 
16 : i32 loc(#loc203) + %xoffset = tt.get_program_id x : i32 loc(#loc204) + %xoffset_2 = arith.constant 32 : i32 loc(#loc205) + %xoffset_3 = arith.constant 32 : i32 loc(#loc205) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc205) + %xindex = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> loc(#loc206) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> loc(#loc207) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<32x1xi32> loc(#loc208) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<32x1xi32> loc(#loc208) + %xmask = arith.constant dense<128> : tensor<32x1xi32> loc(#loc209) + %xmask_8 = arith.cmpi slt, %xindex_7, %xmask : tensor<32x1xi32> loc(#loc209) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc210) + %r0_index_9 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc211) + %r0_offset = arith.constant 0 : i32 loc(#loc212) + %r0_mask = arith.constant true loc(#loc213) + %r0_mask_10 = arith.constant dense : tensor<32x16xi1> loc(#loc213) + %tmp0 = arith.constant 16 : i32 loc(#loc214) + %tmp0_11 = arith.constant 16 : i32 loc(#loc214) + %tmp0_12 = arith.constant dense<16> : tensor<32x1xi32> loc(#loc214) + %tmp0_13 = arith.muli %tmp0_12, %xindex_7 : tensor<32x1xi32> loc(#loc214) + %tmp0_14 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc215) + %tmp0_15 = tt.broadcast %tmp0_13 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc215) + %tmp0_16 = arith.addi %tmp0_14, %tmp0_15 : tensor<32x16xi32> loc(#loc215) + %tmp0_17 = tt.splat %in_ptr0 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc216) + %tmp0_18 = tt.addptr %tmp0_17, %tmp0_16 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc216) + %tmp0_19 = arith.constant 0.000000e+00 : f32 loc(#loc217) + %tmp0_20 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc217) + %tmp0_21 = arith.constant dense<0.000000e+00> : tensor<32x16xf32> 
loc(#loc217) + %tmp0_22 = arith.fptosi %tmp0_21 : tensor<32x16xf32> to tensor<32x16xi64> loc(#loc217) + %tmp0_23 = tt.load %tmp0_18, %tmp0_20, %tmp0_22 : tensor<32x16x!tt.ptr> loc(#loc217) + %tmp1 = arith.constant 0 : i64 loc(#loc218) + %tmp1_24 = arith.constant dense<0> : tensor<1x1xi64> loc(#loc218) + %tmp2 = arith.constant dense<0> : tensor<32x16xi64> loc(#loc219) + %tmp2_25 = arith.cmpi sgt, %tmp0_23, %tmp2 : tensor<32x16xi64> loc(#loc219) + %tmp3 = arith.constant 16384 : i64 loc(#loc220) + %tmp3_26 = arith.constant dense<16384> : tensor<1x1xi64> loc(#loc220) + %tmp4 = arith.constant dense<16384> : tensor<32x16xi64> loc(#loc221) + %tmp4_27 = arith.cmpi slt, %tmp0_23, %tmp4 : tensor<32x16xi64> loc(#loc221) + %tmp5 = arith.andi %tmp2_25, %tmp4_27 : tensor<32x16xi1> loc(#loc222) + %tmp6 = arith.extui %tmp5 : tensor<32x16xi1> to tensor<32x16xi8> loc(#loc223) + %tmp7 = arith.extsi %tmp6 : tensor<32x16xi8> to tensor<32x16xi32> loc(#loc224) + %tmp9 = arith.trunci %r0_index_9 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc225) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16> -> tensor<32x16xi16> loc(#loc226) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S32_16S_i16S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp7, %tmp11) : (tensor<32x16xi32>, tensor<32x16xi16>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc26) + %tmp14 = arith.constant dense<16384> : tensor<32x16xi64> loc(#loc227) + %tmp14_28 = arith.cmpi eq, %tmp0_23, %tmp14 : tensor<32x16xi64> loc(#loc227) + %tmp15 = arith.extui %tmp14_28 : tensor<32x16xi1> to tensor<32x16xi8> loc(#loc228) + %tmp16 = arith.extsi %tmp15 : tensor<32x16xi8> to tensor<32x16xi32> loc(#loc229) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S32_16S_i16S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp16, %tmp11) : (tensor<32x16xi32>, tensor<32x16xi16>) -> (tensor<32x16xi32>, 
tensor<32x16xi32>) loc(#loc30) + %tmp20 = arith.extsi %tmp7 : tensor<32x16xi32> to tensor<32x16xi64> loc(#loc230) + %tmp23 = arith.constant 0 : i32 loc(#loc231) + %tmp23_29 = arith.constant 0 : i64 loc(#loc231) + %tmp23_30 = arith.constant dense<0> : tensor<32x16xi64> loc(#loc231) + %tmp23_31 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc231) + %tmp23_32 = arith.select %tmp23_31, %tmp20, %tmp23_30 : tensor<32x16xi1>, tensor<32x16xi64> loc(#loc231) + %tmp24 = tt.call @"triton.language.standard.sum__i64S32_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp23_32) : (tensor<32x16xi64>) -> tensor<32xi64> loc(#loc232) + %tmp24_33 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> loc(#loc233) + %tmp25 = arith.extsi %tmp16 : tensor<32x16xi32> to tensor<32x16xi64> loc(#loc234) + %tmp28 = arith.constant 0 : i32 loc(#loc235) + %tmp28_34 = arith.constant 0 : i64 loc(#loc235) + %tmp28_35 = arith.constant dense<0> : tensor<32x16xi64> loc(#loc235) + %tmp28_36 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc235) + %tmp28_37 = arith.select %tmp28_36, %tmp25, %tmp28_35 : tensor<32x16xi1>, tensor<32x16xi64> loc(#loc235) + %tmp29 = tt.call @"triton.language.standard.sum__i64S32_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp28_37) : (tensor<32x16xi64>) -> tensor<32xi64> loc(#loc236) + %tmp29_38 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> loc(#loc237) + %tmp30 = arith.trunci %tmp24_33 : tensor<32x1xi64> to tensor<32x1xi32> loc(#loc238) + %tmp31 = arith.trunci %tmp29_38 : tensor<32x1xi64> to tensor<32x1xi32> loc(#loc239) + %tmp32 = arith.extsi %0#1 : tensor<32x16xi32> to tensor<32x16xi64> loc(#loc240) + %tmp33 = arith.trunci %tmp32 : tensor<32x16xi64> to tensor<32x16xi32> loc(#loc241) + %tmp34 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc242) + %tmp34_39 = tt.broadcast %tmp30 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc242) 
+ %tmp34_40 = arith.cmpi slt, %tmp34, %tmp34_39 : tensor<32x16xi32> loc(#loc242) + %tmp35 = arith.constant 16 : i32 loc(#loc243) + %tmp35_41 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc243) + %tmp36 = arith.constant dense<16> : tensor<32x16xi32> loc(#loc244) + %tmp36_42 = arith.select %tmp34_40, %tmp33, %tmp36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc244) + %tmp37 = arith.constant 17 : i32 loc(#loc245) + %tmp37_43 = arith.constant dense<17> : tensor<32x16xi32> loc(#loc245) + %tmp38 = arith.addi %tmp36_42, %tmp37_43 : tensor<32x16xi32> loc(#loc246) + %tmp39 = arith.constant 0 : i32 loc(#loc247) + %tmp39_44 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc247) + %tmp39_45 = arith.cmpi slt, %tmp36_42, %tmp39_44 : tensor<32x16xi32> loc(#loc247) + %tmp40 = arith.select %tmp39_45, %tmp38, %tmp36_42 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc248) + %c0_i32 = arith.constant 0 : i32 loc(#loc50) + %cst = arith.constant dense<0> : tensor<32x16xi32> loc(#loc50) + %2 = arith.cmpi sle, %cst, %tmp40 : tensor<32x16xi32> loc(#loc50) + %c17_i32 = arith.constant 17 : i32 loc(#loc51) + %cst_46 = arith.constant dense<17> : tensor<32x16xi32> loc(#loc51) + %3 = arith.cmpi slt, %tmp40, %cst_46 : tensor<32x16xi32> loc(#loc51) + %4 = arith.andi %2, %3 : tensor<32x16xi1> loc(#loc52) + %true = arith.constant true loc(#loc53) + %cst_47 = arith.constant dense : tensor<32x1xi1> loc(#loc53) + %5 = arith.xori %xmask_8, %cst_47 : tensor<32x1xi1> loc(#loc53) + %6 = tt.broadcast %5 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc54) + %7 = arith.ori %4, %6 : tensor<32x16xi1> loc(#loc54) + tt.assert %7, "index out of bounds: 0 <= tmp40 < 17" : tensor<32x16xi1> loc(#loc55) + %tmp42 = arith.constant 1 : i32 loc(#loc249) + %tmp42_48 = arith.constant dense<1> : tensor<1x1xi32> loc(#loc249) + %tmp43 = arith.extsi %1#1 : tensor<32x16xi32> to tensor<32x16xi64> loc(#loc250) + %tmp44 = arith.trunci %tmp43 : tensor<32x16xi64> to tensor<32x16xi32> loc(#loc251) + %tmp45 = tt.broadcast 
%r0_index_9 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc252) + %tmp45_49 = tt.broadcast %tmp31 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc252) + %tmp45_50 = arith.cmpi slt, %tmp45, %tmp45_49 : tensor<32x16xi32> loc(#loc252) + %tmp46 = arith.constant dense<16> : tensor<32x16xi32> loc(#loc253) + %tmp46_51 = arith.select %tmp45_50, %tmp44, %tmp46 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc253) + %tmp47 = arith.addi %tmp46_51, %tmp37_43 : tensor<32x16xi32> loc(#loc254) + %tmp48 = arith.constant 0 : i32 loc(#loc255) + %tmp48_52 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc255) + %tmp48_53 = arith.cmpi slt, %tmp46_51, %tmp48_52 : tensor<32x16xi32> loc(#loc255) + %tmp49 = arith.select %tmp48_53, %tmp47, %tmp46_51 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc256) + %c0_i32_54 = arith.constant 0 : i32 loc(#loc64) + %cst_55 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc64) + %8 = arith.cmpi sle, %cst_55, %tmp49 : tensor<32x16xi32> loc(#loc64) + %c17_i32_56 = arith.constant 17 : i32 loc(#loc65) + %cst_57 = arith.constant dense<17> : tensor<32x16xi32> loc(#loc65) + %9 = arith.cmpi slt, %tmp49, %cst_57 : tensor<32x16xi32> loc(#loc65) + %10 = arith.andi %8, %9 : tensor<32x16xi1> loc(#loc66) + %true_58 = arith.constant true loc(#loc67) + %cst_59 = arith.constant dense : tensor<32x1xi1> loc(#loc67) + %11 = arith.xori %xmask_8, %cst_59 : tensor<32x1xi1> loc(#loc67) + %12 = tt.broadcast %11 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc68) + %13 = arith.ori %10, %12 : tensor<32x16xi1> loc(#loc68) + tt.assert %13, "index out of bounds: 0 <= tmp49 < 17" : tensor<32x16xi1> loc(#loc69) + %14 = tt.splat %out_ptr4 : !tt.ptr -> tensor<32x1x!tt.ptr> loc(#loc70) + %15 = tt.addptr %14, %xindex_7 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> loc(#loc70) + tt.store %15, %tmp30, %xmask_8 : tensor<32x1x!tt.ptr> loc(#loc71) + %16 = tt.splat %out_ptr5 : !tt.ptr -> tensor<32x1x!tt.ptr> loc(#loc72) + %17 = tt.addptr %16, %xindex_7 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> 
loc(#loc72) + tt.store %17, %tmp31, %xmask_8 : tensor<32x1x!tt.ptr> loc(#loc73) + %c16_i32 = arith.constant 16 : i32 loc(#loc74) + %c16_i32_60 = arith.constant 16 : i32 loc(#loc74) + %cst_61 = arith.constant dense<16> : tensor<32x1xi32> loc(#loc74) + %18 = arith.muli %cst_61, %xindex_7 : tensor<32x1xi32> loc(#loc74) + %19 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc75) + %20 = tt.broadcast %18 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc75) + %21 = arith.addi %19, %20 : tensor<32x16xi32> loc(#loc75) + %22 = tt.splat %out_ptr6 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc76) + %23 = tt.addptr %22, %21 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc76) + %24 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc77) + tt.store %23, %tmp33, %24 : tensor<32x16x!tt.ptr> loc(#loc77) + %c17_i32_62 = arith.constant 17 : i32 loc(#loc78) + %c17_i32_63 = arith.constant 17 : i32 loc(#loc78) + %cst_64 = arith.constant dense<17> : tensor<32x1xi32> loc(#loc78) + %25 = arith.muli %cst_64, %xindex_7 : tensor<32x1xi32> loc(#loc78) + %26 = tt.broadcast %25 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc79) + %27 = arith.addi %tmp40, %26 : tensor<32x16xi32> loc(#loc79) + %28 = tt.splat %out_ptr7 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc80) + %29 = tt.addptr %28, %27 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc80) + %cst_65 = arith.constant dense<1> : tensor<32x16xi32> loc(#loc81) + %30 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc81) + tt.store %29, %cst_65, %30 : tensor<32x16x!tt.ptr> loc(#loc81) + %c16_i32_66 = arith.constant 16 : i32 loc(#loc82) + %c16_i32_67 = arith.constant 16 : i32 loc(#loc82) + %cst_68 = arith.constant dense<16> : tensor<32x1xi32> loc(#loc82) + %31 = arith.muli %cst_68, %xindex_7 : tensor<32x1xi32> loc(#loc82) + %32 = tt.broadcast %r0_index_9 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc83) + %33 = tt.broadcast %31 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc83) + %34 
= arith.addi %32, %33 : tensor<32x16xi32> loc(#loc83) + %35 = tt.splat %out_ptr8 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc84) + %36 = tt.addptr %35, %34 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc84) + %37 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc85) + tt.store %36, %tmp44, %37 : tensor<32x16x!tt.ptr> loc(#loc85) + %c17_i32_69 = arith.constant 17 : i32 loc(#loc86) + %c17_i32_70 = arith.constant 17 : i32 loc(#loc86) + %cst_71 = arith.constant dense<17> : tensor<32x1xi32> loc(#loc86) + %38 = arith.muli %cst_71, %xindex_7 : tensor<32x1xi32> loc(#loc86) + %39 = tt.broadcast %38 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc87) + %40 = arith.addi %tmp49, %39 : tensor<32x16xi32> loc(#loc87) + %41 = tt.splat %out_ptr9 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc88) + %42 = tt.addptr %41, %40 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc88) + %cst_72 = arith.constant dense<1> : tensor<32x16xi32> loc(#loc89) + %43 = tt.broadcast %xmask_8 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc89) + tt.store %42, %cst_72, %43 : tensor<32x16x!tt.ptr> loc(#loc89) + tt.return loc(#loc90) + } loc(#loc) + tt.func private @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S32_16S_i16S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc91)), %idxs: tensor<32x16xi16> loc("idxs"(#loc91))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i16S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs) : (tensor<32x16xi32>, tensor<32x16xi16>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc92) + %1:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1) : (tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc92) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1) : (tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc92) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1) : (tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc92) + tt.return %3#0, %3#1 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc93) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc94) + %5 = ub.poison : tensor<32x16xi32> loc(#loc94) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc94) + } loc(#loc91) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i16S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc95)), %idxs: tensor<32x16xi16> loc("idxs"(#loc95))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : 
tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc264) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i16S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<32x16xi32>, tensor<32x16xi16>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + tt.return %0#0, %0#1 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc101) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32x16xi32> loc(#loc102) + %2 = ub.poison : tensor<32x16xi32> loc(#loc102) + tt.return %1, %2 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i16S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi16> loc("idxs"(#loc103)), %flip: tensor<32x16xi32> loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli 
%y, %ileft : tensor<256x2x1xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<256x2x1xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi16> -> tensor<256x2x1xi16> loc(#loc282) + %left_idx = arith.trunci %left_mask_4 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc283) + %left_idx_15 = tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<256x2x1xi16> loc(#loc284) + %left_idx_16 = arith.muli %y_idx, %left_idx_15 : tensor<256x2x1xi16> loc(#loc284) + %left_idx_17 = tt.call @"triton.language.standard.sum__i16S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_16) : (tensor<256x2x1xi16>) -> tensor<256x1xi32> loc(#loc285) + %left_idx_18 = tt.expand_dims %left_idx_17 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc286) + %left_idx_19 = tt.broadcast %left_idx_18 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc287) + %right_idx = arith.trunci %right_mask_1 : 
tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc288) + %right_idx_20 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<256x2x1xi16> loc(#loc289) + %right_idx_21 = arith.muli %y_idx, %right_idx_20 : tensor<256x2x1xi16> loc(#loc289) + %right_idx_22 = tt.call @"triton.language.standard.sum__i16S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_21) : (tensor<256x2x1xi16>) -> tensor<256x1xi32> loc(#loc290) + %right_idx_23 = tt.expand_dims %right_idx_22 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc291) + %right_idx_24 = tt.broadcast %right_idx_23 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_25 = tt.reshape %left_idx_19 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_26 = tt.reshape %right_idx_24 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_27 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_28 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_49 = arith.constant true loc(#loc300) + %cond_50 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_51 = arith.xori %left_isnan, %cond_50 : tensor<32x16xi1> loc(#loc300) + %cond_52 = arith.andi %right_isnan, %cond_51 : tensor<32x16xi1> loc(#loc301) + %cond_53 = arith.ori %cond, %cond_52 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_53 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + 
} loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_49 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_50 = arith.ori %eq, %eq_49 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_50 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_29 = arith.cmpi sgt, %left_idx_25, %right_idx_26 : tensor<32x16xi32> loc(#loc306) + %cond_30 = arith.andi %3, %cond_29 : tensor<32x16xi1> loc(#loc307) + %cond_31 = arith.ori %1, %cond_30 : tensor<32x16xi1> loc(#loc308) + %cond_32 = arith.cmpi ugt, %right_valid_mask_28, %left_valid_mask_27 : tensor<32x16xi1> loc(#loc309) + %cond_33 = arith.cmpi eq, %right_valid_mask_28, %left_valid_mask_27 : tensor<32x16xi1> loc(#loc310) + %cond_34 = arith.andi %cond_33, %cond_31 : tensor<32x16xi1> loc(#loc311) + %cond_35 = arith.ori %cond_32, %cond_34 : tensor<32x16xi1> loc(#loc312) + %cond_36 = arith.extui %cond_35 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc313) + %cond_37 = arith.xori %cond_36, %flip : tensor<32x16xi32> loc(#loc313) + %cond_38 = arith.constant 0 : i32 loc(#loc314) + %cond_39 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc314) + %cond_40 = arith.cmpi ne, %cond_37, %cond_39 : tensor<32x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_41 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_42 = arith.select %cond_40, %ret, %ret_41 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_43 = arith.xori %x, %ret_42 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_25, %right_idx_26 : tensor<32x16xi32> loc(#loc319) + %new_idxs_44 = tt.call 
@triton.language.standard.zeros_like__i16S32_16S__(%idxs) : (tensor<32x16xi16>) -> tensor<32x16xi16> loc(#loc320) + %new_idxs_45 = arith.extsi %new_idxs_44 : tensor<32x16xi16> to tensor<32x16xi32> loc(#loc321) + %new_idxs_46 = arith.select %cond_40, %new_idxs, %new_idxs_45 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_47 = arith.extsi %idxs : tensor<32x16xi16> to tensor<32x16xi32> loc(#loc322) + %new_idxs_48 = arith.xori %new_idxs_47, %new_idxs_46 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_43, %new_idxs_48 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<256x2x1xi32> loc("input"(#loc166))) -> tensor<256x1xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc167) + tt.return %0 : tensor<256x1xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<256x1xi32> loc(#loc169) + tt.return %1 : tensor<256x1xi32> loc(#loc169) + } loc(#loc166) + tt.func private @triton.language.standard._sum_combine__i32_i32__(%a: i32 loc("a"(#loc170)), %b: i32 loc("b"(#loc170))) -> i32 attributes {noinline = false} { + %0 = arith.addi %a, %b : i32 loc(#loc171) + tt.return %0 : i32 loc(#loc172) + ^bb1: // no predecessors + %1 = ub.poison : i32 loc(#loc173) + tt.return %1 : i32 loc(#loc173) + } loc(#loc170) + tt.func private 
@"triton.language.standard.sum__i16S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<256x2x1xi16> loc("input"(#loc166))) -> tensor<256x1xi32> attributes {noinline = false} { + %input_0 = arith.extsi %input : tensor<256x2x1xi16> to tensor<256x2x1xi32> loc(#loc326) + %0 = "tt.reduce"(%input_0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc167) + tt.return %0 : tensor<256x1xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<256x1xi32> loc(#loc169) + tt.return %1 : tensor<256x1xi32> loc(#loc169) + } loc(#loc166) + tt.func private @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%x: tensor<32x16xi32> loc("x"(#loc175))) -> i1 attributes {noinline = false} { + %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc176) + %false = arith.constant false loc(#loc177) + tt.return %false : i1 loc(#loc177) + ^bb1: // no predecessors + %1 = ub.poison : i1 loc(#loc178) + tt.return %1 : i1 loc(#loc178) + } loc(#loc175) + tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S32_16S__(%x: tensor<32x16xi32> loc("x"(#loc179))) -> tensor<32x16xi32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1> loc(#loc180) + %1 = arith.extui %0 : tensor<1xi1> to tensor<1xi32> loc(#loc181) + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc181) + %3 = tt.broadcast %2 : tensor<1x1xi32> -> tensor<32x16xi32> loc(#loc181) + %4 = arith.addi %x, %3 : tensor<32x16xi32> loc(#loc181) + tt.return %4 : tensor<32x16xi32> loc(#loc182) + ^bb1: // no predecessors + %5 = 
ub.poison : tensor<32x16xi32> loc(#loc183) + tt.return %5 : tensor<32x16xi32> loc(#loc183) + } loc(#loc179) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} { + %false = arith.constant false loc(#loc185) + %cst = arith.constant dense : tensor<1xi1> loc(#loc185) + tt.return %cst : tensor<1xi1> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1xi1> loc(#loc187) + tt.return %0 : tensor<1xi1> loc(#loc187) + } loc(#loc184) + tt.func private @triton.language.standard.zeros_like__i32S32_16S__(%input: tensor<32x16xi32> loc("input"(#loc188))) -> tensor<32x16xi32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_32__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() : () -> tensor<32x16xi32> loc(#loc189) + tt.return %0 : tensor<32x16xi32> loc(#loc190) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32x16xi32> loc(#loc191) + tt.return %1 : tensor<32x16xi32> loc(#loc191) + } loc(#loc188) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_32__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() -> tensor<32x16xi32> attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 loc(#loc185) + %cst = arith.constant dense<0> : tensor<32x16xi32> loc(#loc185) + tt.return %cst : tensor<32x16xi32> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<32x16xi32> loc(#loc187) + tt.return %0 : tensor<32x16xi32> loc(#loc187) + } loc(#loc184) + tt.func private @triton.language.standard.zeros_like__i16S32_16S__(%input: tensor<32x16xi16> loc("input"(#loc188))) -> tensor<32x16xi16> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_32__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() : () -> tensor<32x16xi16> loc(#loc189) + tt.return %0 : tensor<32x16xi16> loc(#loc190) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32x16xi16> loc(#loc191) + tt.return %1 
: tensor<32x16xi16> loc(#loc191) + } loc(#loc188) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_32__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() -> tensor<32x16xi16> attributes {noinline = false} { + %c0_i16 = arith.constant 0 : i16 loc(#loc185) + %cst = arith.constant dense<0> : tensor<32x16xi16> loc(#loc185) + tt.return %cst : tensor<32x16xi16> loc(#loc186) + ^bb1: // no predecessors + %0 = ub.poison : tensor<32x16xi16> loc(#loc187) + tt.return %0 : tensor<32x16xi16> loc(#loc187) + } loc(#loc184) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc95)), %idxs: tensor<32x16xi32> loc("idxs"(#loc95))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc264) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<32x16xi32>, tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : 
(tensor<32x16xi32>, tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + tt.return %1#0, %1#1 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc101) + ^bb1: // no predecessors + %2 = ub.poison : tensor<32x16xi32> loc(#loc102) + %3 = ub.poison : tensor<32x16xi32> loc(#loc102) + tt.return %2, %3 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: tensor<32x16xi32> loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<128x2x2xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : 
tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<128x2x2xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<128x2x2xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<128x2x2xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc291) + %right_idx_22 = 
tt.broadcast %right_idx_21 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<32x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<32x16xi1> loc(#loc301) + %cond_49 = arith.ori %cond, %cond_48 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, 
%left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<32x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<32x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } 
loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<128x2x2xi32> loc("input"(#loc166))) -> tensor<128x2xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc167) + tt.return %0 : tensor<128x2xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<128x2xi32> loc(#loc169) + tt.return %1 : tensor<128x2xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: tensor<32x16xi32> loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<256x2x1xi32> 
loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<256x2x1xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<256x2x1xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<256x2x1xi32> 
loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<32x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<32x16xi1> loc(#loc301) + %cond_49 = arith.ori %cond, %cond_48 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 
loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<32x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<32x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_44 = 
arith.xori %idxs, %new_idxs_43 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc95)), %idxs: tensor<32x16xi32> loc("idxs"(#loc95))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc261) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc262) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc262) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc263) + %flip_3 = tt.reshape %flip_2 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc264) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<32x16xi32>, tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : (tensor<32x16xi32>, tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %2:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip_3) : (tensor<32x16xi32>, tensor<32x16xi32>, tensor<32x16xi32>) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + tt.return %2#0, %2#1 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc101) + ^bb1: // no predecessors + %3 = ub.poison : tensor<32x16xi32> loc(#loc102) + %4 = ub.poison : tensor<32x16xi32> loc(#loc102) + tt.return %3, %4 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: tensor<32x16xi32> loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<64x2x4xi32> loc(#loc272) + %ileft_6 = tt.call 
@"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<64x2x4xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<64x2x4xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<64x2x4xi32> loc(#loc289) + %right_idx_20 = tt.call 
@"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_45 = arith.constant true loc(#loc300) + %cond_46 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<32x16xi1> loc(#loc300) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<32x16xi1> loc(#loc301) + %cond_49 = arith.ori %cond, %cond_48 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_49 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + 
%eq_45 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_46 = arith.ori %eq, %eq_45 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_46 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = arith.extui %cond_33 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc313) + %cond_35 = arith.xori %cond_34, %flip : tensor<32x16xi32> loc(#loc313) + %cond_36 = arith.constant 0 : i32 loc(#loc314) + %cond_37 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc314) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<32x16xi32> loc(#loc314) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_41 = arith.xori %x, %ret_40 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<32x16xi32> 
loc(#loc322) + tt.return %ret_41, %new_idxs_44 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x2x4xi32> loc("input"(#loc166))) -> tensor<64x4xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc167) + tt.return %0 : tensor<64x4xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64x4xi32> loc(#loc169) + tt.return %1 : tensor<64x4xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S32_16S_i32S32_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc95)), %idxs: tensor<32x16xi32> loc("idxs"(#loc95))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %flip = arith.constant false loc(#loc330) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip) : (tensor<32x16xi32>, tensor<32x16xi32>, i1) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %1:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip) : (tensor<32x16xi32>, tensor<32x16xi32>, i1) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip) : (tensor<32x16xi32>, tensor<32x16xi32>, i1) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1, %flip) : (tensor<32x16xi32>, tensor<32x16xi32>, i1) -> (tensor<32x16xi32>, tensor<32x16xi32>) loc(#loc100) + tt.return %3#0, %3#1 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc101) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc102) + %5 = ub.poison : tensor<32x16xi32> loc(#loc102) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc102) + } loc(#loc95) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims 
%right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<32x2x8xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S32_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<32x2x8xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S32_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<32x2x8xi32> loc(#loc284) + %left_idx_16 = tt.call 
@"triton.language.standard.sum__i32S32_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<32x2x8xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S32_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_44 = arith.xori 
%left_isnan, %cond_43 : tensor<32x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<32x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<32x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<32x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, 
%right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i32S32_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<32x2x8xi32> loc("input"(#loc166))) -> tensor<32x8xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc167) + tt.reduce.return %2 : i32 loc(#loc167) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc167) + tt.return %0 : tensor<32x8xi32> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32x8xi32> loc(#loc169) + tt.return %1 : tensor<32x8xi32> loc(#loc169) + } loc(#loc166) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 
= tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<64x2x4xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<64x2x4xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : 
tensor<64x2x4xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<64x2x4xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S64_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : 
tensor<32x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<32x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<32x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<32x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<32x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<32x16xi32> 
loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<128x2x2xi32> loc(#loc272) + %ileft_6 = tt.call 
@"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<128x2x2xi32> loc(#loc276) + %iright_10 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<128x2x2xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<128x2x2xi32> loc(#loc289) + %right_idx_20 = tt.call 
@"triton.language.standard.sum__i32S128_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<32x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<32x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> 
(tensor<32x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<32x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<32x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : 
tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S32_16S_i32S32_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<32x16xi32> loc("x"(#loc103)), %idxs: tensor<32x16xi32> loc("idxs"(#loc103)), %flip: i1 loc("flip"(#loc103))) -> (tensor<32x16xi32>, tensor<32x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc268) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc269) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc270) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc270) + %left_mask = arith.constant 1 : i32 loc(#loc271) + %left_mask_2 = arith.constant 1 : i32 loc(#loc271) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc271) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc271) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc272) + %ileft_5 = arith.muli %y, %ileft : tensor<256x2x1xi32> loc(#loc272) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc273) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc274) + %ileft_8 = tt.broadcast %ileft_7 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc275) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc276) + %iright_9 = arith.muli %y, %iright : tensor<256x2x1xi32> loc(#loc276) + %iright_10 = tt.call 
@"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc277) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc278) + %iright_12 = tt.broadcast %iright_11 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc279) + %ileft_13 = tt.reshape %ileft_8 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %iright_14 = tt.reshape %iright_12 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc281) + %y_idx = tt.reshape %idxs : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc282) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc284) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<256x2x1xi32> loc(#loc284) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc285) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc286) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc287) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc289) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<256x2x1xi32> loc(#loc289) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S256_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc290) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc291) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc293) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<256x2x1xi32> -> 
tensor<32x16xi32> loc(#loc294) + %left_valid_mask = arith.constant true loc(#loc295) + %left_valid_mask_25 = arith.constant dense : tensor<32x16xi1> loc(#loc295) + %right_valid_mask = arith.constant true loc(#loc296) + %right_valid_mask_26 = arith.constant dense : tensor<32x16xi1> loc(#loc296) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<32x16xi32> loc(#loc297) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<32x16xi32> loc(#loc298) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc331) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc136) + %1 = scf.if %0 -> (tensor<32x16xi1>) { + %cond_42 = arith.constant true loc(#loc300) + %cond_43 = arith.constant dense : tensor<32x16xi1> loc(#loc300) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<32x16xi1> loc(#loc300) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<32x16xi1> loc(#loc301) + %cond_46 = arith.ori %cond, %cond_45 : tensor<32x16xi1> loc(#loc332) + scf.yield %cond_46 : tensor<32x16xi1> loc(#loc332) + } else { + scf.yield %cond : tensor<32x16xi1> loc(#loc141) + } loc(#loc137) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc333) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S32_16S__(%ileft_13) : (tensor<32x16xi32>) -> i1 loc(#loc143) + %3 = scf.if %2 -> (tensor<32x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<32x16xi1> loc(#loc304) + %eq_43 = arith.ori %eq, %eq_42 : tensor<32x16xi1> loc(#loc334) + scf.yield %eq_43 : tensor<32x16xi1> loc(#loc334) + } else { + scf.yield %eq : tensor<32x16xi1> loc(#loc141) + } loc(#loc144) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc306) + %cond_28 = arith.andi %3, %cond_27 : tensor<32x16xi1> loc(#loc307) + %cond_29 = arith.ori %1, %cond_28 : tensor<32x16xi1> loc(#loc308) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, 
%left_valid_mask_25 : tensor<32x16xi1> loc(#loc309) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<32x16xi1> loc(#loc310) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<32x16xi1> loc(#loc311) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<32x16xi1> loc(#loc312) + %cond_34 = tt.splat %flip : i1 -> tensor<32x16xi1> loc(#loc313) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<32x16xi1> loc(#loc313) + %ret = arith.xori %ileft_13, %iright_14 : tensor<32x16xi32> loc(#loc315) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%x) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc316) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc317) + %ret_38 = arith.xori %x, %ret_37 : tensor<32x16xi32> loc(#loc318) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<32x16xi32> loc(#loc319) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S32_16S__(%idxs) : (tensor<32x16xi32>) -> tensor<32x16xi32> loc(#loc320) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc321) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<32x16xi32> loc(#loc322) + tt.return %ret_38, %new_idxs_41 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc164) + ^bb1: // no predecessors + %4 = ub.poison : tensor<32x16xi32> loc(#loc165) + %5 = ub.poison : tensor<32x16xi32> loc(#loc165) + tt.return %4, %5 : tensor<32x16xi32>, tensor<32x16xi32> loc(#loc165) + } loc(#loc103) + tt.func private @"triton.language.standard.sum__i64S32_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<32x16xi64> loc("input"(#loc166))) -> tensor<32xi64> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i64 loc(unknown), %arg2: i64 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i64_i64__(%arg1, %arg2) : (i64, i64) -> i64 loc(#loc167) + tt.reduce.return %2 : i64 
loc(#loc167) + }) : (tensor<32x16xi64>) -> tensor<32xi64> loc(#loc167) + tt.return %0 : tensor<32xi64> loc(#loc168) + ^bb1: // no predecessors + %1 = ub.poison : tensor<32xi64> loc(#loc169) + tt.return %1 : tensor<32xi64> loc(#loc169) + } loc(#loc166) + tt.func private @triton.language.standard._sum_combine__i64_i64__(%a: i64 loc("a"(#loc170)), %b: i64 loc("b"(#loc170))) -> i64 attributes {noinline = false} { + %0 = arith.addi %a, %b : i64 loc(#loc171) + tt.return %0 : i64 loc(#loc172) + ^bb1: // no predecessors + %1 = ub.poison : i64 loc(#loc173) + tt.return %1 : i64 loc(#loc173) + } loc(#loc170) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":20:15) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":26:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":27:28) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":27:38) +#loc11 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":28:16) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":29:48) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:40) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:37) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:30) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:45) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":35:30) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":36:18) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":37:34) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":38:18) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":39:18) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":40:19) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":41:19) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":43:19) +#loc25 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":45:34) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":46:71) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":47:20) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":48:21) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":49:21) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":51:71) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":52:20) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":54:35) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:26) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:29) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":56:21) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":58:35) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:26) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:29) +#loc39 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":60:21) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":61:21) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":62:21) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":63:21) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":64:19) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":65:32) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":66:35) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":67:44) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":68:20) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":69:20) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":70:35) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:28) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:46) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:38) +#loc53 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:55) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:53) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:63) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":72:31) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":73:21) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":74:21) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":75:19) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":76:35) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":77:20) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":78:20) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":79:35) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:28) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:46) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:38) +#loc67 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:55) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:53) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:63) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:25) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:37) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:25) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:37) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:35) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:32) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:25) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:47) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:52) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:49) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:25) +#loc81 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:85) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:35) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:32) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:25) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:47) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:52) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:49) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:25) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:85) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:4) +#loc92 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc93 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:11) +#loc94 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:4) +#loc96 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc97 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc98 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc99 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc100 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc101 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:11) +#loc102 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:4) +#loc104 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc105 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:30) +#loc106 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:33) +#loc107 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc108 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc109 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc110 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc111 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc112 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc113 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc114 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc115 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc116 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc117 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc118 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc119 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc120 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc121 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc122 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc123 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc124 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc125 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc126 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc127 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc128 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc129 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc130 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc131 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":558:49) +#loc132 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":559:50) +#loc133 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":570:25) +#loc134 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":571:27) +#loc135 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc136 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:23) +#loc137 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:11) +#loc138 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:47) +#loc139 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:46) +#loc140 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:31) +#loc142 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc143 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:23) +#loc144 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:11) +#loc145 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:36) +#loc146 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:23) +#loc147 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc148 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc149 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc150 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":596:31) +#loc151 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:29) +#loc152 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:48) +#loc153 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:8) +#loc154 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc155 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc156 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc157 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:60) +#loc158 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc159 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc160 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc161 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:73) +#loc162 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc163 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc164 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:11) +#loc165 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:4) +#loc167 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc168 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc169 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc171 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc172 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc173 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc174 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc176 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:29) +#loc177 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:11) +#loc178 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:4) +#loc180 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:30) +#loc181 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:15) +#loc182 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:11) +#loc183 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:4) +#loc184 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc185 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc186 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc187 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc189 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:30) +#loc190 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:11) +#loc191 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:4) +#loc192 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":630:15) +#loc202 = loc("xnumel"(#loc1)) +#loc203 = loc("r0_numel"(#loc2)) +#loc204 = loc("xoffset"(#loc3)) +#loc205 = loc("xoffset"(#loc4)) +#loc206 = loc("xindex"(#loc5)) +#loc207 = loc("xindex"(#loc6)) +#loc208 = loc("xindex"(#loc7)) +#loc209 = loc("xmask"(#loc8)) +#loc210 = loc("r0_index"(#loc9)) +#loc211 = loc("r0_index"(#loc10)) +#loc212 = loc("r0_offset"(#loc11)) +#loc213 = loc("r0_mask"(#loc12)) +#loc214 = loc("tmp0"(#loc13)) +#loc215 = loc("tmp0"(#loc14)) +#loc216 = loc("tmp0"(#loc15)) +#loc217 = loc("tmp0"(#loc16)) +#loc218 = loc("tmp1"(#loc17)) +#loc219 = loc("tmp2"(#loc18)) +#loc220 = loc("tmp3"(#loc19)) +#loc221 = loc("tmp4"(#loc20)) +#loc222 = loc("tmp5"(#loc21)) +#loc223 = loc("tmp6"(#loc22)) +#loc224 = loc("tmp7"(#loc23)) +#loc225 = loc("tmp9"(#loc24)) +#loc226 = loc("tmp11"(#loc25)) +#loc227 = loc("tmp14"(#loc27)) +#loc228 = loc("tmp15"(#loc28)) +#loc229 = loc("tmp16"(#loc29)) +#loc230 = loc("tmp20"(#loc31)) +#loc231 = loc("tmp23"(#loc32)) +#loc232 = loc("tmp24"(#loc33)) +#loc233 = loc("tmp24"(#loc34)) +#loc234 = loc("tmp25"(#loc35)) +#loc235 = loc("tmp28"(#loc36)) +#loc236 = loc("tmp29"(#loc37)) +#loc237 = loc("tmp29"(#loc38)) +#loc238 = loc("tmp30"(#loc39)) +#loc239 = loc("tmp31"(#loc40)) +#loc240 = loc("tmp32"(#loc41)) +#loc241 = loc("tmp33"(#loc42)) +#loc242 = loc("tmp34"(#loc43)) +#loc243 = loc("tmp35"(#loc44)) +#loc244 = loc("tmp36"(#loc45)) +#loc245 = loc("tmp37"(#loc46)) +#loc246 = loc("tmp38"(#loc47)) +#loc247 = loc("tmp39"(#loc48)) +#loc248 = loc("tmp40"(#loc49)) +#loc249 = loc("tmp42"(#loc56)) +#loc250 = loc("tmp43"(#loc57)) +#loc251 = loc("tmp44"(#loc58)) +#loc252 = loc("tmp45"(#loc59)) +#loc253 = loc("tmp46"(#loc60)) +#loc254 = loc("tmp47"(#loc61)) +#loc255 = loc("tmp48"(#loc62)) +#loc256 = loc("tmp49"(#loc63)) +#loc261 = 
loc("flip"(#loc96)) +#loc262 = loc("flip"(#loc97)) +#loc263 = loc("flip"(#loc98)) +#loc264 = loc("flip"(#loc99)) +#loc268 = loc("y"(#loc104)) +#loc269 = loc("right_mask"(#loc105)) +#loc270 = loc("right_mask"(#loc106)) +#loc271 = loc("left_mask"(#loc107)) +#loc272 = loc("ileft"(#loc108)) +#loc273 = loc("ileft"(#loc109)) +#loc274 = loc("ileft"(#loc110)) +#loc275 = loc("ileft"(#loc111)) +#loc276 = loc("iright"(#loc112)) +#loc277 = loc("iright"(#loc113)) +#loc278 = loc("iright"(#loc114)) +#loc279 = loc("iright"(#loc115)) +#loc280 = loc("ileft"(#loc116)) +#loc281 = loc("iright"(#loc117)) +#loc282 = loc("y_idx"(#loc118)) +#loc283 = loc("left_idx"(#loc119)) +#loc284 = loc("left_idx"(#loc120)) +#loc285 = loc("left_idx"(#loc121)) +#loc286 = loc("left_idx"(#loc122)) +#loc287 = loc("left_idx"(#loc123)) +#loc288 = loc("right_idx"(#loc124)) +#loc289 = loc("right_idx"(#loc125)) +#loc290 = loc("right_idx"(#loc126)) +#loc291 = loc("right_idx"(#loc127)) +#loc292 = loc("right_idx"(#loc128)) +#loc293 = loc("left_idx"(#loc129)) +#loc294 = loc("right_idx"(#loc130)) +#loc295 = loc("left_valid_mask"(#loc131)) +#loc296 = loc("right_valid_mask"(#loc132)) +#loc297 = loc("left_isnan"(#loc133)) +#loc298 = loc("right_isnan"(#loc134)) +#loc299 = loc("cond"(#loc135)) +#loc300 = loc("cond"(#loc138)) +#loc301 = loc("cond"(#loc139)) +#loc302 = loc("cond"(#loc140)) +#loc303 = loc("eq"(#loc142)) +#loc304 = loc("eq"(#loc145)) +#loc305 = loc("eq"(#loc146)) +#loc306 = loc("cond"(#loc147)) +#loc307 = loc("cond"(#loc148)) +#loc308 = loc("cond"(#loc149)) +#loc309 = loc("cond"(#loc150)) +#loc310 = loc("cond"(#loc151)) +#loc311 = loc("cond"(#loc152)) +#loc312 = loc("cond"(#loc153)) +#loc313 = loc("cond"(#loc154)) +#loc314 = loc("cond"(#loc155)) +#loc315 = loc("ret"(#loc156)) +#loc316 = loc("ret"(#loc157)) +#loc317 = loc("ret"(#loc158)) +#loc318 = loc("ret"(#loc159)) +#loc319 = loc("new_idxs"(#loc160)) +#loc320 = loc("new_idxs"(#loc161)) +#loc321 = loc("new_idxs"(#loc162)) +#loc322 = loc("new_idxs"(#loc163)) 
+#loc326 = loc("input"(#loc174)) +#loc330 = loc("flip"(#loc192)) +#loc331 = loc("cond"(#loc299)) +#loc332 = loc("cond"(#loc302)) +#loc333 = loc("eq"(#loc303)) +#loc334 = loc("eq"(#loc305)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..09f8f37324662708ff1984a7ad7b7569e6ef20d4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir @@ -0,0 +1,1480 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 2, 2], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [16, 2, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":18:0) +#loc1 = loc(unknown) +#loc20 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":46:71) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":51:71) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:26) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:26) +#loc117 = loc("in_ptr0"(#loc)) +#loc118 = loc("out_ptr4"(#loc)) +#loc119 = loc("out_ptr5"(#loc)) +#loc120 = loc("out_ptr6"(#loc)) +#loc121 = loc("out_ptr7"(#loc)) +#loc122 = loc("out_ptr8"(#loc)) +#loc123 = loc("out_ptr9"(#loc)) +#loc124 = loc("xnumel"(#loc)) +#loc125 = loc("r0_numel"(#loc)) +#loc144 = loc(callsite(#loc20 at #loc21)) +#loc150 = loc("ileft"(#loc29)) +#loc154 = loc("iright"(#loc34)) +#loc163 = loc("left_idx"(#loc43)) +#loc168 = loc("right_idx"(#loc48)) +#loc189 = loc(callsite(#loc20 at #loc69)) +#loc192 = loc("tmp24"(#loc72)) +#loc197 = loc("tmp29"(#loc77)) +#loc214 = loc(callsite(#loc25 at #loc144)) +#loc218 = loc(callsite(#loc25 at #loc189)) +#loc221 = loc(callsite(#loc1 at #loc192)) +#loc224 = loc(callsite(#loc1 at #loc197)) +#loc228 = 
loc(callsite(#loc150 at #loc214)) +#loc232 = loc(callsite(#loc154 at #loc214)) +#loc240 = loc(callsite(#loc163 at #loc214)) +#loc245 = loc(callsite(#loc168 at #loc214)) +#loc265 = loc(callsite(#loc150 at #loc218)) +#loc269 = loc(callsite(#loc154 at #loc218)) +#loc287 = loc(callsite(#loc163 at #loc218)) +#loc291 = loc(callsite(#loc168 at #loc218)) +#loc301 = loc(callsite(#loc1 at #loc228)) +#loc303 = loc(callsite(#loc1 at #loc232)) +#loc306 = loc(callsite(#loc1 at #loc240)) +#loc309 = loc(callsite(#loc1 at #loc245)) +#loc311 = loc(callsite(#loc1 at #loc265)) +#loc313 = loc(callsite(#loc1 at #loc269)) +#loc315 = loc(callsite(#loc1 at #loc287)) +#loc317 = loc(callsite(#loc1 at #loc291)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense : tensor<32x1xi1, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<17> : tensor<32x1xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked2> loc(#loc1) + %cst_3 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked3> loc(#loc1) + 
%cst_4 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked4> loc(#loc1) + %cst_5 = arith.constant dense<16> : tensor<32x1xi32, #blocked> loc(#loc1) + %cst_6 = arith.constant dense<128> : tensor<32x1xi32, #blocked5> loc(#loc1) + %cst_7 = arith.constant dense<128> : tensor<32x1xi32, #blocked> loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %cst_8 = arith.constant dense<1> : tensor<32x16xi32, #blocked5> loc(#loc1) + %cst_9 = arith.constant dense<0> : tensor<32x16xi32, #blocked> loc(#loc1) + %cst_10 = arith.constant dense<17> : tensor<32x16xi32, #blocked> loc(#loc1) + %cst_11 = arith.constant dense<16> : tensor<32x16xi32, #blocked> loc(#loc1) + %cst_12 = arith.constant dense<16384> : tensor<32x16xi64, #blocked> loc(#loc1) + %cst_13 = arith.constant dense<0> : tensor<32x16xi64, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc126) + %xoffset_14 = arith.muli %xoffset, %c32_i32 : i32 loc(#loc127) + %xindex = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc128) + %xindex_15 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc128) + %xindex_16 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> loc(#loc128) + %xindex_17 = tt.expand_dims %xindex_15 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<32x1xi32, #blocked5> loc(#loc128) + %xindex_18 = tt.splat %xoffset_14 : i32 -> tensor<32x1xi32, #blocked> loc(#loc129) + %xindex_19 = tt.splat %xoffset_14 : i32 -> tensor<32x1xi32, #blocked5> loc(#loc129) + %xindex_20 = arith.addi %xindex_18, %xindex_16 : tensor<32x1xi32, #blocked> loc(#loc129) + %xindex_21 = arith.addi %xindex_19, %xindex_17 : tensor<32x1xi32, #blocked5> loc(#loc129) + %xmask = arith.cmpi slt, %xindex_20, %cst_7 : tensor<32x1xi32, #blocked> loc(#loc130) + %xmask_22 = arith.cmpi slt, 
%xindex_21, %cst_6 : tensor<32x1xi32, #blocked5> loc(#loc130) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc131) + %r0_index_23 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc131) + %tmp0 = arith.muli %xindex_20, %cst_5 : tensor<32x1xi32, #blocked> loc(#loc132) + %tmp0_24 = tt.broadcast %r0_index_23 : tensor<1x16xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc133) + %tmp0_25 = tt.broadcast %tmp0 : tensor<32x1xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc133) + %tmp0_26 = arith.addi %tmp0_24, %tmp0_25 : tensor<32x16xi32, #blocked> loc(#loc133) + %tmp0_27 = tt.splat %in_ptr0 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> loc(#loc134) + %tmp0_28 = tt.addptr %tmp0_27, %tmp0_26 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked> loc(#loc134) + %tmp0_29 = tt.broadcast %xmask : tensor<32x1xi1, #blocked> -> tensor<32x16xi1, #blocked> loc(#loc135) + %tmp0_30 = tt.broadcast %xmask_22 : tensor<32x1xi1, #blocked5> -> tensor<32x16xi1, #blocked5> loc(#loc135) + %tmp0_31 = tt.load %tmp0_28, %tmp0_29, %cst_13 : tensor<32x16x!tt.ptr, #blocked> loc(#loc135) + %tmp2 = arith.cmpi sgt, %tmp0_31, %cst_13 : tensor<32x16xi64, #blocked> loc(#loc136) + %tmp4 = arith.cmpi slt, %tmp0_31, %cst_12 : tensor<32x16xi64, #blocked> loc(#loc137) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<32x16xi1, #blocked> loc(#loc138) + %tmp7 = arith.extui %tmp5 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc210) + %tmp9 = arith.trunci %r0_index_23 : tensor<1x16xi32, #blocked> to tensor<1x16xi16, #blocked> loc(#loc141) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16, #blocked> -> tensor<32x16xi16, #blocked> loc(#loc142) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> loc(#loc211) + %flip_32 = 
tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> loc(#loc211) + %flip_33 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> loc(#loc211) + %flip_34 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> loc(#loc211) + %flip_35 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> loc(#loc211) + %flip_36 = tt.expand_dims %flip_32 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> loc(#loc211) + %flip_37 = tt.expand_dims %flip_33 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> loc(#loc211) + %flip_38 = tt.expand_dims %flip_34 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> loc(#loc211) + %flip_39 = tt.expand_dims %flip_35 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> -> tensor<1x2x1xi32, #blocked3> loc(#loc211) + %flip_40 = tt.expand_dims %flip_36 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> -> tensor<1x2x1xi32, #blocked4> loc(#loc211) + %flip_41 = tt.expand_dims %flip_37 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> loc(#loc211) + %flip_42 = tt.expand_dims %flip_38 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> -> tensor<1x2x1xi32, #blocked1> loc(#loc211) + 
%flip_43 = tt.broadcast %flip_39 : tensor<1x2x1xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc212) + %flip_44 = tt.reshape %flip_43 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc213) + %y = tt.reshape %tmp7 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc225) + %left_mask = arith.subi %cst_4, %flip_40 : tensor<1x2x1xi32, #blocked4> loc(#loc226) + %left_mask_45 = arith.subi %cst_3, %flip_39 : tensor<1x2x1xi32, #blocked3> loc(#loc226) + %left_mask_46 = arith.subi %cst_2, %flip_41 : tensor<1x2x1xi32, #blocked2> loc(#loc226) + %left_mask_47 = arith.subi %cst_1, %flip_42 : tensor<1x2x1xi32, #blocked1> loc(#loc226) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc227) + %ileft_48 = arith.muli %y, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc227) + %ileft_49 = "tt.reduce"(%ileft_48) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_50 = tt.expand_dims %ileft_49 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc229) + %ileft_51 = tt.broadcast %ileft_50 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc230) + %iright = tt.broadcast %flip_40 : tensor<1x2x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc231) + %iright_52 = arith.muli %y, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc231) + %iright_53 = "tt.reduce"(%iright_52) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + 
tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_54 = tt.expand_dims %iright_53 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc233) + %iright_55 = tt.broadcast %iright_54 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc234) + %ileft_56 = tt.reshape %ileft_51 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_57 = tt.reshape %iright_55 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx = tt.reshape %tmp11 : tensor<32x16xi16, #blocked> -> tensor<256x2x1xi16, #blocked4> loc(#loc237) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc238) + %left_idx_58 = tt.broadcast %left_idx : tensor<1x2x1xi16, #blocked4> -> tensor<256x2x1xi16, #blocked4> loc(#loc239) + %left_idx_59 = arith.muli %y_idx, %left_idx_58 : tensor<256x2x1xi16, #blocked4> loc(#loc239) + %input = arith.extsi %left_idx_59 : tensor<256x2x1xi16, #blocked4> to tensor<256x2x1xi32, #blocked4> loc(#loc304) + %left_idx_60 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_61 = tt.expand_dims %left_idx_60 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc241) + %left_idx_62 = tt.broadcast %left_idx_61 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc242) + %right_idx = arith.trunci %flip_40 : tensor<1x2x1xi32, 
#blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc243) + %right_idx_63 = tt.broadcast %right_idx : tensor<1x2x1xi16, #blocked4> -> tensor<256x2x1xi16, #blocked4> loc(#loc244) + %right_idx_64 = arith.muli %y_idx, %right_idx_63 : tensor<256x2x1xi16, #blocked4> loc(#loc244) + %input_65 = arith.extsi %right_idx_64 : tensor<256x2x1xi16, #blocked4> to tensor<256x2x1xi32, #blocked4> loc(#loc307) + %right_idx_66 = "tt.reduce"(%input_65) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_67 = tt.expand_dims %right_idx_66 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc246) + %right_idx_68 = tt.broadcast %right_idx_67 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc247) + %left_idx_69 = tt.reshape %left_idx_62 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_70 = tt.reshape %right_idx_68 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond = arith.cmpi slt, %ileft_56, %iright_57 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq = arith.cmpi eq, %ileft_56, %iright_57 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_71 = arith.cmpi sgt, %left_idx_69, %right_idx_70 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_72 = arith.andi %eq, %cond_71 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_73 = arith.ori %cond, %cond_72 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_74 = arith.extui %cond_73 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc255) + %cond_75 = arith.xori %cond_74, %flip_44 : tensor<32x16xi32, #blocked> loc(#loc255) + 
%cond_76 = arith.cmpi ne, %cond_75, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret = arith.xori %ileft_56, %iright_57 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_77 = arith.select %cond_76, %ret, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_78 = arith.xori %tmp7, %ret_77 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs = arith.xori %left_idx_69, %right_idx_70 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_79 = arith.select %cond_76, %new_idxs, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_80 = arith.extsi %tmp9 : tensor<1x16xi16, #blocked> to tensor<1x16xi32, #blocked> loc(#loc262) + %new_idxs_81 = tt.broadcast %new_idxs_80 : tensor<1x16xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc262) + %new_idxs_82 = arith.xori %new_idxs_81, %new_idxs_79 : tensor<32x16xi32, #blocked> loc(#loc262) + %flip_83 = tt.broadcast %flip_41 : tensor<1x2x1xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc212) + %flip_84 = tt.reshape %flip_83 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc213) + %y_85 = tt.reshape %ret_78 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc225) + %ileft_86 = tt.broadcast %left_mask_45 : tensor<1x2x1xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc227) + %ileft_87 = arith.muli %y_85, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc227) + %ileft_88 = "tt.reduce"(%ileft_87) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_89 = tt.expand_dims %ileft_88 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, 
#blocked3> loc(#loc229) + %ileft_90 = tt.broadcast %ileft_89 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc230) + %iright_91 = arith.muli %y_85, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc231) + %iright_92 = "tt.reduce"(%iright_91) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_93 = tt.expand_dims %iright_92 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc233) + %iright_94 = tt.broadcast %iright_93 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc234) + %ileft_95 = tt.reshape %ileft_90 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_96 = tt.reshape %iright_94 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_97 = tt.reshape %new_idxs_82 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc237) + %left_idx_98 = arith.muli %y_idx_97, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc239) + %left_idx_99 = "tt.reduce"(%left_idx_98) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_100 = tt.expand_dims %left_idx_99 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc241) + %left_idx_101 = tt.broadcast %left_idx_100 
: tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc242) + %right_idx_102 = arith.muli %y_idx_97, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc244) + %right_idx_103 = "tt.reduce"(%right_idx_102) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_104 = tt.expand_dims %right_idx_103 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc246) + %right_idx_105 = tt.broadcast %right_idx_104 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc247) + %left_idx_106 = tt.reshape %left_idx_101 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_107 = tt.reshape %right_idx_105 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_108 = arith.cmpi slt, %ileft_95, %iright_96 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_109 = arith.cmpi eq, %ileft_95, %iright_96 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_110 = arith.cmpi sgt, %left_idx_106, %right_idx_107 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_111 = arith.andi %eq_109, %cond_110 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_112 = arith.ori %cond_108, %cond_111 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_113 = arith.extui %cond_112 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc255) + %cond_114 = arith.xori %cond_113, %flip_84 : tensor<32x16xi32, #blocked> loc(#loc255) + %cond_115 = arith.cmpi ne, %cond_114, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret_116 = arith.xori %ileft_95, %iright_96 : tensor<32x16xi32, #blocked> loc(#loc257) + 
%ret_117 = arith.select %cond_115, %ret_116, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_118 = arith.xori %ret_78, %ret_117 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_119 = arith.xori %left_idx_106, %right_idx_107 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_120 = arith.select %cond_115, %new_idxs_119, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_121 = arith.xori %new_idxs_82, %new_idxs_120 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_122 = tt.reshape %ret_118 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc225) + %ileft_123 = arith.muli %y_122, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc227) + %ileft_124 = "tt.reduce"(%ileft_123) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_125 = tt.expand_dims %ileft_124 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc229) + %ileft_126 = tt.broadcast %ileft_125 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc230) + %iright_127 = arith.muli %y_122, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc231) + %iright_128 = "tt.reduce"(%iright_127) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_129 = tt.expand_dims %iright_128 {axis = 1 : i32} : 
tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc233) + %iright_130 = tt.broadcast %iright_129 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc234) + %ileft_131 = tt.reshape %ileft_126 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_132 = tt.reshape %iright_130 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_133 = tt.reshape %new_idxs_121 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc237) + %left_idx_134 = arith.muli %y_idx_133, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc239) + %left_idx_135 = "tt.reduce"(%left_idx_134) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_136 = tt.expand_dims %left_idx_135 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc241) + %left_idx_137 = tt.broadcast %left_idx_136 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc242) + %right_idx_138 = arith.muli %y_idx_133, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc244) + %right_idx_139 = "tt.reduce"(%right_idx_138) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_140 = tt.expand_dims %right_idx_139 {axis = 1 : i32} : 
tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc246) + %right_idx_141 = tt.broadcast %right_idx_140 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc247) + %left_idx_142 = tt.reshape %left_idx_137 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_143 = tt.reshape %right_idx_141 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_144 = arith.cmpi slt, %ileft_131, %iright_132 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_145 = arith.cmpi eq, %ileft_131, %iright_132 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_146 = arith.cmpi sgt, %left_idx_142, %right_idx_143 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_147 = arith.andi %eq_145, %cond_146 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_148 = arith.ori %cond_144, %cond_147 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_149 = arith.extui %cond_148 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc255) + %cond_150 = arith.xori %cond_149, %flip_84 : tensor<32x16xi32, #blocked> loc(#loc255) + %cond_151 = arith.cmpi ne, %cond_150, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret_152 = arith.xori %ileft_131, %iright_132 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_153 = arith.select %cond_151, %ret_152, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_154 = arith.xori %ret_118, %ret_153 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_155 = arith.xori %left_idx_142, %right_idx_143 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_156 = arith.select %cond_151, %new_idxs_155, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_157 = arith.xori %new_idxs_121, %new_idxs_156 : tensor<32x16xi32, #blocked> loc(#loc262) + %flip_158 = tt.broadcast %flip_42 : tensor<1x2x1xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc212) + %flip_159 = 
tt.reshape %flip_158 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc213) + %y_160 = tt.reshape %ret_154 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc225) + %ileft_161 = tt.broadcast %left_mask_46 : tensor<1x2x1xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc227) + %ileft_162 = arith.muli %y_160, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc227) + %ileft_163 = "tt.reduce"(%ileft_162) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc300) + %ileft_164 = tt.expand_dims %ileft_163 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc229) + %ileft_165 = tt.broadcast %ileft_164 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc230) + %iright_166 = arith.muli %y_160, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc231) + %iright_167 = "tt.reduce"(%iright_166) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %iright_168 = tt.expand_dims %iright_167 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc233) + %iright_169 = tt.broadcast %iright_168 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc234) + %ileft_170 = tt.reshape %ileft_165 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc235) + 
%iright_171 = tt.reshape %iright_169 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_172 = tt.reshape %new_idxs_157 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc237) + %left_idx_173 = arith.muli %y_idx_172, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc239) + %left_idx_174 = "tt.reduce"(%left_idx_173) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %left_idx_175 = tt.expand_dims %left_idx_174 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc241) + %left_idx_176 = tt.broadcast %left_idx_175 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc242) + %right_idx_177 = arith.muli %y_idx_172, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc244) + %right_idx_178 = "tt.reduce"(%right_idx_177) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc308) + %right_idx_179 = tt.expand_dims %right_idx_178 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc246) + %right_idx_180 = tt.broadcast %right_idx_179 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc247) + %left_idx_181 = tt.reshape %left_idx_176 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc248) + 
%right_idx_182 = tt.reshape %right_idx_180 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_183 = arith.cmpi slt, %ileft_170, %iright_171 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_184 = arith.cmpi eq, %ileft_170, %iright_171 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_185 = arith.cmpi sgt, %left_idx_181, %right_idx_182 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_186 = arith.andi %eq_184, %cond_185 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_187 = arith.ori %cond_183, %cond_186 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_188 = arith.extui %cond_187 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc255) + %cond_189 = arith.xori %cond_188, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc255) + %cond_190 = arith.cmpi ne, %cond_189, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret_191 = arith.xori %ileft_170, %iright_171 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_192 = arith.select %cond_190, %ret_191, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_193 = arith.xori %ret_154, %ret_192 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_194 = arith.xori %left_idx_181, %right_idx_182 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_195 = arith.select %cond_190, %new_idxs_194, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_196 = arith.xori %new_idxs_157, %new_idxs_195 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_197 = tt.reshape %ret_193 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc225) + %ileft_198 = arith.muli %y_197, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc227) + %ileft_199 = "tt.reduce"(%ileft_198) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + 
}) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_200 = tt.expand_dims %ileft_199 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc229) + %ileft_201 = tt.broadcast %ileft_200 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc230) + %iright_202 = arith.muli %y_197, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc231) + %iright_203 = "tt.reduce"(%iright_202) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_204 = tt.expand_dims %iright_203 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc233) + %iright_205 = tt.broadcast %iright_204 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc234) + %ileft_206 = tt.reshape %ileft_201 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_207 = tt.reshape %iright_205 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_208 = tt.reshape %new_idxs_196 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc237) + %left_idx_209 = arith.muli %y_idx_208, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc239) + %left_idx_210 = "tt.reduce"(%left_idx_209) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<128x2x2xi32, #blocked3>) -> 
tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_211 = tt.expand_dims %left_idx_210 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc241) + %left_idx_212 = tt.broadcast %left_idx_211 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc242) + %right_idx_213 = arith.muli %y_idx_208, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc244) + %right_idx_214 = "tt.reduce"(%right_idx_213) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_215 = tt.expand_dims %right_idx_214 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc246) + %right_idx_216 = tt.broadcast %right_idx_215 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc247) + %left_idx_217 = tt.reshape %left_idx_212 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_218 = tt.reshape %right_idx_216 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_219 = arith.cmpi slt, %ileft_206, %iright_207 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_220 = arith.cmpi eq, %ileft_206, %iright_207 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_221 = arith.cmpi sgt, %left_idx_217, %right_idx_218 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_222 = arith.andi %eq_220, %cond_221 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_223 = arith.ori %cond_219, %cond_222 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_224 = arith.extui %cond_223 : tensor<32x16xi1, #blocked> to 
tensor<32x16xi32, #blocked> loc(#loc255) + %cond_225 = arith.xori %cond_224, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc255) + %cond_226 = arith.cmpi ne, %cond_225, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret_227 = arith.xori %ileft_206, %iright_207 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_228 = arith.select %cond_226, %ret_227, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_229 = arith.xori %ret_193, %ret_228 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_230 = arith.xori %left_idx_217, %right_idx_218 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_231 = arith.select %cond_226, %new_idxs_230, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_232 = arith.xori %new_idxs_196, %new_idxs_231 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_233 = tt.reshape %ret_229 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc225) + %ileft_234 = arith.muli %y_233, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc227) + %ileft_235 = "tt.reduce"(%ileft_234) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + %ileft_236 = tt.expand_dims %ileft_235 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc229) + %ileft_237 = tt.broadcast %ileft_236 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc230) + %iright_238 = arith.muli %y_233, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc231) + %iright_239 = "tt.reduce"(%iright_238) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at 
#loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_240 = tt.expand_dims %iright_239 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc233) + %iright_241 = tt.broadcast %iright_240 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc234) + %ileft_242 = tt.reshape %ileft_237 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_243 = tt.reshape %iright_241 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_244 = tt.reshape %new_idxs_232 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc237) + %left_idx_245 = arith.muli %y_idx_244, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc239) + %left_idx_246 = "tt.reduce"(%left_idx_245) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_247 = tt.expand_dims %left_idx_246 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc241) + %left_idx_248 = tt.broadcast %left_idx_247 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc242) + %right_idx_249 = arith.muli %y_idx_244, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc244) + %right_idx_250 = "tt.reduce"(%right_idx_249) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + 
%right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_251 = tt.expand_dims %right_idx_250 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc246) + %right_idx_252 = tt.broadcast %right_idx_251 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc247) + %left_idx_253 = tt.reshape %left_idx_248 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_254 = tt.reshape %right_idx_252 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_255 = arith.cmpi slt, %ileft_242, %iright_243 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_256 = arith.cmpi eq, %ileft_242, %iright_243 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_257 = arith.cmpi sgt, %left_idx_253, %right_idx_254 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_258 = arith.andi %eq_256, %cond_257 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_259 = arith.ori %cond_255, %cond_258 : tensor<32x16xi1, #blocked> loc(#loc254) + %cond_260 = arith.extui %cond_259 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc255) + %cond_261 = arith.xori %cond_260, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc255) + %cond_262 = arith.cmpi ne, %cond_261, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc256) + %ret_263 = arith.xori %ileft_242, %iright_243 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_264 = arith.select %cond_262, %ret_263, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_265 = arith.xori %ret_229, %ret_264 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_266 = arith.xori %left_idx_253, %right_idx_254 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_267 = arith.select %cond_262, 
%new_idxs_266, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_268 = arith.xori %new_idxs_232, %new_idxs_267 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_269 = tt.reshape %ret_265 : tensor<32x16xi32, #blocked> -> tensor<32x2x8xi32, #blocked1> loc(#loc225) + %ileft_270 = tt.broadcast %left_mask_47 : tensor<1x2x1xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc227) + %ileft_271 = arith.muli %y_269, %ileft_270 : tensor<32x2x8xi32, #blocked1> loc(#loc227) + %ileft_272 = "tt.reduce"(%ileft_271) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc300) + %ileft_273 = tt.expand_dims %ileft_272 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc229) + %ileft_274 = tt.broadcast %ileft_273 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc230) + %iright_275 = arith.muli %y_269, %flip_158 : tensor<32x2x8xi32, #blocked1> loc(#loc231) + %iright_276 = "tt.reduce"(%iright_275) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc302) + %iright_277 = tt.expand_dims %iright_276 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc233) + %iright_278 = tt.broadcast %iright_277 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc234) + %ileft_279 = 
tt.reshape %ileft_274 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_280 = tt.reshape %iright_278 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_281 = tt.reshape %new_idxs_268 : tensor<32x16xi32, #blocked> -> tensor<32x2x8xi32, #blocked1> loc(#loc237) + %left_idx_282 = arith.muli %y_idx_281, %ileft_270 : tensor<32x2x8xi32, #blocked1> loc(#loc239) + %left_idx_283 = "tt.reduce"(%left_idx_282) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc305) + %left_idx_284 = tt.expand_dims %left_idx_283 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc241) + %left_idx_285 = tt.broadcast %left_idx_284 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc242) + %right_idx_286 = arith.muli %y_idx_281, %flip_158 : tensor<32x2x8xi32, #blocked1> loc(#loc244) + %right_idx_287 = "tt.reduce"(%right_idx_286) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc308) + %right_idx_288 = tt.expand_dims %right_idx_287 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc246) + %right_idx_289 = tt.broadcast %right_idx_288 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc247) + %left_idx_290 = tt.reshape 
%left_idx_285 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_291 = tt.reshape %right_idx_289 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_292 = arith.cmpi slt, %ileft_279, %iright_280 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_293 = arith.cmpi eq, %ileft_279, %iright_280 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_294 = arith.cmpi sgt, %left_idx_290, %right_idx_291 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_295 = arith.andi %eq_293, %cond_294 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_296 = arith.ori %cond_292, %cond_295 : tensor<32x16xi1, #blocked> loc(#loc254) + %ret_297 = arith.xori %ileft_279, %iright_280 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_298 = arith.select %cond_296, %ret_297, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_299 = arith.xori %ret_265, %ret_298 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_300 = arith.xori %left_idx_290, %right_idx_291 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_301 = arith.select %cond_296, %new_idxs_300, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_302 = arith.xori %new_idxs_268, %new_idxs_301 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_303 = tt.reshape %ret_299 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc225) + %ileft_304 = arith.muli %y_303, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc227) + %ileft_305 = "tt.reduce"(%ileft_304) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc300) + %ileft_306 = tt.expand_dims %ileft_305 {axis = 1 : i32} : tensor<64x4xi32, 
#ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc229) + %ileft_307 = tt.broadcast %ileft_306 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc230) + %iright_308 = arith.muli %y_303, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc231) + %iright_309 = "tt.reduce"(%iright_308) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %iright_310 = tt.expand_dims %iright_309 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc233) + %iright_311 = tt.broadcast %iright_310 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc234) + %ileft_312 = tt.reshape %ileft_307 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_313 = tt.reshape %iright_311 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_314 = tt.reshape %new_idxs_302 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc237) + %left_idx_315 = arith.muli %y_idx_314, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc239) + %left_idx_316 = "tt.reduce"(%left_idx_315) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %left_idx_317 = tt.expand_dims %left_idx_316 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> 
tensor<64x1x4xi32, #blocked2> loc(#loc241) + %left_idx_318 = tt.broadcast %left_idx_317 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc242) + %right_idx_319 = arith.muli %y_idx_314, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc244) + %right_idx_320 = "tt.reduce"(%right_idx_319) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc308) + %right_idx_321 = tt.expand_dims %right_idx_320 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc246) + %right_idx_322 = tt.broadcast %right_idx_321 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc247) + %left_idx_323 = tt.reshape %left_idx_318 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_324 = tt.reshape %right_idx_322 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_325 = arith.cmpi slt, %ileft_312, %iright_313 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_326 = arith.cmpi eq, %ileft_312, %iright_313 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_327 = arith.cmpi sgt, %left_idx_323, %right_idx_324 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_328 = arith.andi %eq_326, %cond_327 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_329 = arith.ori %cond_325, %cond_328 : tensor<32x16xi1, #blocked> loc(#loc254) + %ret_330 = arith.xori %ileft_312, %iright_313 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_331 = arith.select %cond_329, %ret_330, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_332 = arith.xori %ret_299, %ret_331 : tensor<32x16xi32, #blocked> 
loc(#loc259) + %new_idxs_333 = arith.xori %left_idx_323, %right_idx_324 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_334 = arith.select %cond_329, %new_idxs_333, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_335 = arith.xori %new_idxs_302, %new_idxs_334 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_336 = tt.reshape %ret_332 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc225) + %ileft_337 = arith.muli %y_336, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc227) + %ileft_338 = "tt.reduce"(%ileft_337) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc300) + %ileft_339 = tt.expand_dims %ileft_338 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc229) + %ileft_340 = tt.broadcast %ileft_339 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc230) + %iright_341 = arith.muli %y_336, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc231) + %iright_342 = "tt.reduce"(%iright_341) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %iright_343 = tt.expand_dims %iright_342 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc233) + %iright_344 = tt.broadcast %iright_343 : tensor<128x1x2xi32, #blocked3> -> 
tensor<128x2x2xi32, #blocked3> loc(#loc234) + %ileft_345 = tt.reshape %ileft_340 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_346 = tt.reshape %iright_344 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_347 = tt.reshape %new_idxs_335 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc237) + %left_idx_348 = arith.muli %y_idx_347, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc239) + %left_idx_349 = "tt.reduce"(%left_idx_348) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %left_idx_350 = tt.expand_dims %left_idx_349 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc241) + %left_idx_351 = tt.broadcast %left_idx_350 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc242) + %right_idx_352 = arith.muli %y_idx_347, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc244) + %right_idx_353 = "tt.reduce"(%right_idx_352) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc308) + %right_idx_354 = tt.expand_dims %right_idx_353 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc246) + %right_idx_355 = tt.broadcast %right_idx_354 : tensor<128x1x2xi32, #blocked3> 
-> tensor<128x2x2xi32, #blocked3> loc(#loc247) + %left_idx_356 = tt.reshape %left_idx_351 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_357 = tt.reshape %right_idx_355 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_358 = arith.cmpi slt, %ileft_345, %iright_346 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_359 = arith.cmpi eq, %ileft_345, %iright_346 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_360 = arith.cmpi sgt, %left_idx_356, %right_idx_357 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_361 = arith.andi %eq_359, %cond_360 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_362 = arith.ori %cond_358, %cond_361 : tensor<32x16xi1, #blocked> loc(#loc254) + %ret_363 = arith.xori %ileft_345, %iright_346 : tensor<32x16xi32, #blocked> loc(#loc257) + %ret_364 = arith.select %cond_362, %ret_363, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc258) + %ret_365 = arith.xori %ret_332, %ret_364 : tensor<32x16xi32, #blocked> loc(#loc259) + %new_idxs_366 = arith.xori %left_idx_356, %right_idx_357 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_367 = arith.select %cond_362, %new_idxs_366, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc261) + %new_idxs_368 = arith.xori %new_idxs_335, %new_idxs_367 : tensor<32x16xi32, #blocked> loc(#loc262) + %y_369 = tt.reshape %ret_365 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc225) + %ileft_370 = arith.muli %y_369, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc227) + %ileft_371 = "tt.reduce"(%ileft_370) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc228)), %ileft_742: i32 loc(callsite(#loc1 at #loc228))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc318) + tt.reduce.return %ileft_743 : i32 loc(#loc300) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc300) + 
%ileft_372 = tt.expand_dims %ileft_371 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc229) + %ileft_373 = tt.broadcast %ileft_372 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc230) + %iright_374 = arith.muli %y_369, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc231) + %iright_375 = "tt.reduce"(%iright_374) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc232)), %iright_742: i32 loc(callsite(#loc1 at #loc232))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc319) + tt.reduce.return %iright_743 : i32 loc(#loc302) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %iright_376 = tt.expand_dims %iright_375 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc233) + %iright_377 = tt.broadcast %iright_376 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc234) + %ileft_378 = tt.reshape %ileft_373 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc235) + %iright_379 = tt.reshape %iright_377 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc236) + %y_idx_380 = tt.reshape %new_idxs_368 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc237) + %left_idx_381 = arith.muli %y_idx_380, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc239) + %left_idx_382 = "tt.reduce"(%left_idx_381) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc240)), %left_idx_742: i32 loc(callsite(#loc1 at #loc240))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc320) + tt.reduce.return %left_idx_743 : i32 loc(#loc305) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %left_idx_383 = tt.expand_dims %left_idx_382 {axis = 1 : 
i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc241) + %left_idx_384 = tt.broadcast %left_idx_383 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc242) + %right_idx_385 = arith.muli %y_idx_380, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc244) + %right_idx_386 = "tt.reduce"(%right_idx_385) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc245)), %right_idx_742: i32 loc(callsite(#loc1 at #loc245))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc321) + tt.reduce.return %right_idx_743 : i32 loc(#loc308) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc308) + %right_idx_387 = tt.expand_dims %right_idx_386 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc246) + %right_idx_388 = tt.broadcast %right_idx_387 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc247) + %left_idx_389 = tt.reshape %left_idx_384 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc248) + %right_idx_390 = tt.reshape %right_idx_388 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc249) + %cond_391 = arith.cmpi slt, %ileft_378, %iright_379 : tensor<32x16xi32, #blocked> loc(#loc250) + %eq_392 = arith.cmpi eq, %ileft_378, %iright_379 : tensor<32x16xi32, #blocked> loc(#loc251) + %cond_393 = arith.cmpi sgt, %left_idx_389, %right_idx_390 : tensor<32x16xi32, #blocked> loc(#loc252) + %cond_394 = arith.andi %eq_392, %cond_393 : tensor<32x16xi1, #blocked> loc(#loc253) + %cond_395 = arith.ori %cond_391, %cond_394 : tensor<32x16xi1, #blocked> loc(#loc254) + %new_idxs_396 = arith.xori %left_idx_389, %right_idx_390 : tensor<32x16xi32, #blocked> loc(#loc260) + %new_idxs_397 = arith.select %cond_395, %new_idxs_396, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, 
#blocked> loc(#loc261) + %new_idxs_398 = arith.xori %new_idxs_368, %new_idxs_397 : tensor<32x16xi32, #blocked> loc(#loc262) + %tmp14 = arith.cmpi eq, %tmp0_31, %cst_12 : tensor<32x16xi64, #blocked> loc(#loc186) + %tmp16 = arith.extui %tmp14 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc217) + %y_399 = tt.reshape %tmp16 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc263) + %ileft_400 = arith.muli %y_399, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc264) + %ileft_401 = "tt.reduce"(%ileft_400) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_402 = tt.expand_dims %ileft_401 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc266) + %ileft_403 = tt.broadcast %ileft_402 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc267) + %iright_404 = arith.muli %y_399, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc268) + %iright_405 = "tt.reduce"(%iright_404) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_406 = tt.expand_dims %iright_405 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc270) + %iright_407 = tt.broadcast %iright_406 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc271) + %ileft_408 = 
tt.reshape %ileft_403 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_409 = tt.reshape %iright_407 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc273) + %cond_410 = arith.cmpi slt, %ileft_408, %iright_409 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_411 = arith.cmpi eq, %ileft_408, %iright_409 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_412 = arith.andi %eq_411, %cond_71 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_413 = arith.ori %cond_410, %cond_412 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_414 = arith.extui %cond_413 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_415 = arith.xori %cond_414, %flip_44 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_416 = arith.cmpi ne, %cond_415, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc279) + %ret_417 = arith.xori %ileft_408, %iright_409 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_418 = arith.select %cond_416, %ret_417, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_419 = arith.xori %tmp16, %ret_418 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_420 = arith.select %cond_416, %new_idxs, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_421 = arith.xori %new_idxs_81, %new_idxs_420 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_422 = tt.reshape %ret_419 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc263) + %ileft_423 = arith.muli %y_422, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc264) + %ileft_424 = "tt.reduce"(%ileft_423) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + 
%ileft_425 = tt.expand_dims %ileft_424 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc266) + %ileft_426 = tt.broadcast %ileft_425 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc267) + %iright_427 = arith.muli %y_422, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc268) + %iright_428 = "tt.reduce"(%iright_427) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_429 = tt.expand_dims %iright_428 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc270) + %iright_430 = tt.broadcast %iright_429 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc271) + %ileft_431 = tt.reshape %ileft_426 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_432 = tt.reshape %iright_430 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_433 = tt.reshape %new_idxs_421 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc285) + %left_idx_434 = arith.muli %y_idx_433, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc286) + %left_idx_435 = "tt.reduce"(%left_idx_434) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_436 = tt.expand_dims %left_idx_435 {axis = 
1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc288) + %left_idx_437 = tt.broadcast %left_idx_436 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc289) + %right_idx_438 = arith.muli %y_idx_433, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc290) + %right_idx_439 = "tt.reduce"(%right_idx_438) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_440 = tt.expand_dims %right_idx_439 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc292) + %right_idx_441 = tt.broadcast %right_idx_440 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc293) + %left_idx_442 = tt.reshape %left_idx_437 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_443 = tt.reshape %right_idx_441 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_444 = arith.cmpi slt, %ileft_431, %iright_432 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_445 = arith.cmpi eq, %ileft_431, %iright_432 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_446 = arith.cmpi sgt, %left_idx_442, %right_idx_443 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_447 = arith.andi %eq_445, %cond_446 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_448 = arith.ori %cond_444, %cond_447 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_449 = arith.extui %cond_448 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_450 = arith.xori %cond_449, %flip_84 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_451 = 
arith.cmpi ne, %cond_450, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc279) + %ret_452 = arith.xori %ileft_431, %iright_432 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_453 = arith.select %cond_451, %ret_452, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_454 = arith.xori %ret_419, %ret_453 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_455 = arith.xori %left_idx_442, %right_idx_443 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_456 = arith.select %cond_451, %new_idxs_455, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_457 = arith.xori %new_idxs_421, %new_idxs_456 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_458 = tt.reshape %ret_454 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc263) + %ileft_459 = arith.muli %y_458, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc264) + %ileft_460 = "tt.reduce"(%ileft_459) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_461 = tt.expand_dims %ileft_460 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc266) + %ileft_462 = tt.broadcast %ileft_461 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc267) + %iright_463 = arith.muli %y_458, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc268) + %iright_464 = "tt.reduce"(%iright_463) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : 
(tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_465 = tt.expand_dims %iright_464 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc270) + %iright_466 = tt.broadcast %iright_465 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc271) + %ileft_467 = tt.reshape %ileft_462 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_468 = tt.reshape %iright_466 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_469 = tt.reshape %new_idxs_457 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc285) + %left_idx_470 = arith.muli %y_idx_469, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc286) + %left_idx_471 = "tt.reduce"(%left_idx_470) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_472 = tt.expand_dims %left_idx_471 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc288) + %left_idx_473 = tt.broadcast %left_idx_472 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc289) + %right_idx_474 = arith.muli %y_idx_469, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc290) + %right_idx_475 = "tt.reduce"(%right_idx_474) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : 
(tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_476 = tt.expand_dims %right_idx_475 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc292) + %right_idx_477 = tt.broadcast %right_idx_476 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc293) + %left_idx_478 = tt.reshape %left_idx_473 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_479 = tt.reshape %right_idx_477 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_480 = arith.cmpi slt, %ileft_467, %iright_468 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_481 = arith.cmpi eq, %ileft_467, %iright_468 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_482 = arith.cmpi sgt, %left_idx_478, %right_idx_479 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_483 = arith.andi %eq_481, %cond_482 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_484 = arith.ori %cond_480, %cond_483 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_485 = arith.extui %cond_484 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_486 = arith.xori %cond_485, %flip_84 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_487 = arith.cmpi ne, %cond_486, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc279) + %ret_488 = arith.xori %ileft_467, %iright_468 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_489 = arith.select %cond_487, %ret_488, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_490 = arith.xori %ret_454, %ret_489 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_491 = arith.xori %left_idx_478, %right_idx_479 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_492 = arith.select %cond_487, %new_idxs_491, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_493 = arith.xori %new_idxs_457, 
%new_idxs_492 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_494 = tt.reshape %ret_490 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc263) + %ileft_495 = arith.muli %y_494, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc264) + %ileft_496 = "tt.reduce"(%ileft_495) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc310) + %ileft_497 = tt.expand_dims %ileft_496 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc266) + %ileft_498 = tt.broadcast %ileft_497 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc267) + %iright_499 = arith.muli %y_494, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc268) + %iright_500 = "tt.reduce"(%iright_499) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc312) + %iright_501 = tt.expand_dims %iright_500 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc270) + %iright_502 = tt.broadcast %iright_501 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc271) + %ileft_503 = tt.reshape %ileft_498 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_504 = tt.reshape %iright_502 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_505 = tt.reshape %new_idxs_493 : 
tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc285) + %left_idx_506 = arith.muli %y_idx_505, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc286) + %left_idx_507 = "tt.reduce"(%left_idx_506) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc314) + %left_idx_508 = tt.expand_dims %left_idx_507 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc288) + %left_idx_509 = tt.broadcast %left_idx_508 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc289) + %right_idx_510 = arith.muli %y_idx_505, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc290) + %right_idx_511 = "tt.reduce"(%right_idx_510) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc316) + %right_idx_512 = tt.expand_dims %right_idx_511 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc292) + %right_idx_513 = tt.broadcast %right_idx_512 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc293) + %left_idx_514 = tt.reshape %left_idx_509 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_515 = tt.reshape %right_idx_513 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_516 = arith.cmpi slt, %ileft_503, 
%iright_504 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_517 = arith.cmpi eq, %ileft_503, %iright_504 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_518 = arith.cmpi sgt, %left_idx_514, %right_idx_515 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_519 = arith.andi %eq_517, %cond_518 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_520 = arith.ori %cond_516, %cond_519 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_521 = arith.extui %cond_520 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_522 = arith.xori %cond_521, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_523 = arith.cmpi ne, %cond_522, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc279) + %ret_524 = arith.xori %ileft_503, %iright_504 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_525 = arith.select %cond_523, %ret_524, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_526 = arith.xori %ret_490, %ret_525 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_527 = arith.xori %left_idx_514, %right_idx_515 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_528 = arith.select %cond_523, %new_idxs_527, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_529 = arith.xori %new_idxs_493, %new_idxs_528 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_530 = tt.reshape %ret_526 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc263) + %ileft_531 = arith.muli %y_530, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc264) + %ileft_532 = "tt.reduce"(%ileft_531) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + %ileft_533 = tt.expand_dims %ileft_532 {axis 
= 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc266) + %ileft_534 = tt.broadcast %ileft_533 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc267) + %iright_535 = arith.muli %y_530, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc268) + %iright_536 = "tt.reduce"(%iright_535) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_537 = tt.expand_dims %iright_536 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc270) + %iright_538 = tt.broadcast %iright_537 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc271) + %ileft_539 = tt.reshape %ileft_534 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_540 = tt.reshape %iright_538 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_541 = tt.reshape %new_idxs_529 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc285) + %left_idx_542 = arith.muli %y_idx_541, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc286) + %left_idx_543 = "tt.reduce"(%left_idx_542) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_544 = tt.expand_dims %left_idx_543 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim 
= 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc288) + %left_idx_545 = tt.broadcast %left_idx_544 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc289) + %right_idx_546 = arith.muli %y_idx_541, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc290) + %right_idx_547 = "tt.reduce"(%right_idx_546) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_548 = tt.expand_dims %right_idx_547 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc292) + %right_idx_549 = tt.broadcast %right_idx_548 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc293) + %left_idx_550 = tt.reshape %left_idx_545 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_551 = tt.reshape %right_idx_549 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_552 = arith.cmpi slt, %ileft_539, %iright_540 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_553 = arith.cmpi eq, %ileft_539, %iright_540 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_554 = arith.cmpi sgt, %left_idx_550, %right_idx_551 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_555 = arith.andi %eq_553, %cond_554 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_556 = arith.ori %cond_552, %cond_555 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_557 = arith.extui %cond_556 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_558 = arith.xori %cond_557, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_559 = arith.cmpi ne, %cond_558, %cst_9 : 
tensor<32x16xi32, #blocked> loc(#loc279) + %ret_560 = arith.xori %ileft_539, %iright_540 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_561 = arith.select %cond_559, %ret_560, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_562 = arith.xori %ret_526, %ret_561 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_563 = arith.xori %left_idx_550, %right_idx_551 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_564 = arith.select %cond_559, %new_idxs_563, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_565 = arith.xori %new_idxs_529, %new_idxs_564 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_566 = tt.reshape %ret_562 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc263) + %ileft_567 = arith.muli %y_566, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc264) + %ileft_568 = "tt.reduce"(%ileft_567) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_569 = tt.expand_dims %ileft_568 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc266) + %ileft_570 = tt.broadcast %ileft_569 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc267) + %iright_571 = arith.muli %y_566, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc268) + %iright_572 = "tt.reduce"(%iright_571) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32, #blocked4>) -> 
tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_573 = tt.expand_dims %iright_572 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc270) + %iright_574 = tt.broadcast %iright_573 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc271) + %ileft_575 = tt.reshape %ileft_570 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_576 = tt.reshape %iright_574 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_577 = tt.reshape %new_idxs_565 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc285) + %left_idx_578 = arith.muli %y_idx_577, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc286) + %left_idx_579 = "tt.reduce"(%left_idx_578) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_580 = tt.expand_dims %left_idx_579 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc288) + %left_idx_581 = tt.broadcast %left_idx_580 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc289) + %right_idx_582 = arith.muli %y_idx_577, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc290) + %right_idx_583 = "tt.reduce"(%right_idx_582) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<256x2x1xi32, #blocked4>) -> 
tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_584 = tt.expand_dims %right_idx_583 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc292) + %right_idx_585 = tt.broadcast %right_idx_584 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc293) + %left_idx_586 = tt.reshape %left_idx_581 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_587 = tt.reshape %right_idx_585 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_588 = arith.cmpi slt, %ileft_575, %iright_576 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_589 = arith.cmpi eq, %ileft_575, %iright_576 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_590 = arith.cmpi sgt, %left_idx_586, %right_idx_587 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_591 = arith.andi %eq_589, %cond_590 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_592 = arith.ori %cond_588, %cond_591 : tensor<32x16xi1, #blocked> loc(#loc277) + %cond_593 = arith.extui %cond_592 : tensor<32x16xi1, #blocked> to tensor<32x16xi32, #blocked> loc(#loc278) + %cond_594 = arith.xori %cond_593, %flip_159 : tensor<32x16xi32, #blocked> loc(#loc278) + %cond_595 = arith.cmpi ne, %cond_594, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc279) + %ret_596 = arith.xori %ileft_575, %iright_576 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_597 = arith.select %cond_595, %ret_596, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_598 = arith.xori %ret_562, %ret_597 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_599 = arith.xori %left_idx_586, %right_idx_587 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_600 = arith.select %cond_595, %new_idxs_599, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_601 = arith.xori %new_idxs_565, %new_idxs_600 : tensor<32x16xi32, 
#blocked> loc(#loc284) + %y_602 = tt.reshape %ret_598 : tensor<32x16xi32, #blocked> -> tensor<32x2x8xi32, #blocked1> loc(#loc263) + %ileft_603 = arith.muli %y_602, %ileft_270 : tensor<32x2x8xi32, #blocked1> loc(#loc264) + %ileft_604 = "tt.reduce"(%ileft_603) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc310) + %ileft_605 = tt.expand_dims %ileft_604 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc266) + %ileft_606 = tt.broadcast %ileft_605 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc267) + %iright_607 = arith.muli %y_602, %flip_158 : tensor<32x2x8xi32, #blocked1> loc(#loc268) + %iright_608 = "tt.reduce"(%iright_607) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc312) + %iright_609 = tt.expand_dims %iright_608 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc270) + %iright_610 = tt.broadcast %iright_609 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc271) + %ileft_611 = tt.reshape %ileft_606 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_612 = tt.reshape %iright_610 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_613 = tt.reshape %new_idxs_601 : tensor<32x16xi32, #blocked> -> 
tensor<32x2x8xi32, #blocked1> loc(#loc285) + %left_idx_614 = arith.muli %y_idx_613, %ileft_270 : tensor<32x2x8xi32, #blocked1> loc(#loc286) + %left_idx_615 = "tt.reduce"(%left_idx_614) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc314) + %left_idx_616 = tt.expand_dims %left_idx_615 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc288) + %left_idx_617 = tt.broadcast %left_idx_616 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc289) + %right_idx_618 = arith.muli %y_idx_613, %flip_158 : tensor<32x2x8xi32, #blocked1> loc(#loc290) + %right_idx_619 = "tt.reduce"(%right_idx_618) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<32x2x8xi32, #blocked1>) -> tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc316) + %right_idx_620 = tt.expand_dims %right_idx_619 {axis = 1 : i32} : tensor<32x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1x8xi32, #blocked1> loc(#loc292) + %right_idx_621 = tt.broadcast %right_idx_620 : tensor<32x1x8xi32, #blocked1> -> tensor<32x2x8xi32, #blocked1> loc(#loc293) + %left_idx_622 = tt.reshape %left_idx_617 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_623 = tt.reshape %right_idx_621 : tensor<32x2x8xi32, #blocked1> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_624 = arith.cmpi slt, %ileft_611, %iright_612 : tensor<32x16xi32, 
#blocked> loc(#loc274) + %eq_625 = arith.cmpi eq, %ileft_611, %iright_612 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_626 = arith.cmpi sgt, %left_idx_622, %right_idx_623 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_627 = arith.andi %eq_625, %cond_626 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_628 = arith.ori %cond_624, %cond_627 : tensor<32x16xi1, #blocked> loc(#loc277) + %ret_629 = arith.xori %ileft_611, %iright_612 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_630 = arith.select %cond_628, %ret_629, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_631 = arith.xori %ret_598, %ret_630 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_632 = arith.xori %left_idx_622, %right_idx_623 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_633 = arith.select %cond_628, %new_idxs_632, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_634 = arith.xori %new_idxs_601, %new_idxs_633 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_635 = tt.reshape %ret_631 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc263) + %ileft_636 = arith.muli %y_635, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc264) + %ileft_637 = "tt.reduce"(%ileft_636) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc310) + %ileft_638 = tt.expand_dims %ileft_637 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc266) + %ileft_639 = tt.broadcast %ileft_638 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc267) + %iright_640 = arith.muli %y_635, %flip_83 : tensor<64x2x4xi32, #blocked2> 
loc(#loc268) + %iright_641 = "tt.reduce"(%iright_640) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc312) + %iright_642 = tt.expand_dims %iright_641 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc270) + %iright_643 = tt.broadcast %iright_642 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc271) + %ileft_644 = tt.reshape %ileft_639 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_645 = tt.reshape %iright_643 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_646 = tt.reshape %new_idxs_634 : tensor<32x16xi32, #blocked> -> tensor<64x2x4xi32, #blocked2> loc(#loc285) + %left_idx_647 = arith.muli %y_idx_646, %ileft_161 : tensor<64x2x4xi32, #blocked2> loc(#loc286) + %left_idx_648 = "tt.reduce"(%left_idx_647) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc314) + %left_idx_649 = tt.expand_dims %left_idx_648 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc288) + %left_idx_650 = tt.broadcast %left_idx_649 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc289) + %right_idx_651 = arith.muli %y_idx_646, %flip_83 : tensor<64x2x4xi32, #blocked2> loc(#loc290) + %right_idx_652 = 
"tt.reduce"(%right_idx_651) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<64x2x4xi32, #blocked2>) -> tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc316) + %right_idx_653 = tt.expand_dims %right_idx_652 {axis = 1 : i32} : tensor<64x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x4xi32, #blocked2> loc(#loc292) + %right_idx_654 = tt.broadcast %right_idx_653 : tensor<64x1x4xi32, #blocked2> -> tensor<64x2x4xi32, #blocked2> loc(#loc293) + %left_idx_655 = tt.reshape %left_idx_650 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_656 = tt.reshape %right_idx_654 : tensor<64x2x4xi32, #blocked2> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_657 = arith.cmpi slt, %ileft_644, %iright_645 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_658 = arith.cmpi eq, %ileft_644, %iright_645 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_659 = arith.cmpi sgt, %left_idx_655, %right_idx_656 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_660 = arith.andi %eq_658, %cond_659 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_661 = arith.ori %cond_657, %cond_660 : tensor<32x16xi1, #blocked> loc(#loc277) + %ret_662 = arith.xori %ileft_644, %iright_645 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_663 = arith.select %cond_661, %ret_662, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_664 = arith.xori %ret_631, %ret_663 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_665 = arith.xori %left_idx_655, %right_idx_656 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_666 = arith.select %cond_661, %new_idxs_665, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_667 = arith.xori 
%new_idxs_634, %new_idxs_666 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_668 = tt.reshape %ret_664 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc263) + %ileft_669 = arith.muli %y_668, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc264) + %ileft_670 = "tt.reduce"(%ileft_669) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc310) + %ileft_671 = tt.expand_dims %ileft_670 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc266) + %ileft_672 = tt.broadcast %ileft_671 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc267) + %iright_673 = arith.muli %y_668, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc268) + %iright_674 = "tt.reduce"(%iright_673) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc312) + %iright_675 = tt.expand_dims %iright_674 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc270) + %iright_676 = tt.broadcast %iright_675 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc271) + %ileft_677 = tt.reshape %ileft_672 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_678 = tt.reshape %iright_676 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_679 = 
tt.reshape %new_idxs_667 : tensor<32x16xi32, #blocked> -> tensor<128x2x2xi32, #blocked3> loc(#loc285) + %left_idx_680 = arith.muli %y_idx_679, %ileft_86 : tensor<128x2x2xi32, #blocked3> loc(#loc286) + %left_idx_681 = "tt.reduce"(%left_idx_680) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc314) + %left_idx_682 = tt.expand_dims %left_idx_681 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc288) + %left_idx_683 = tt.broadcast %left_idx_682 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc289) + %right_idx_684 = arith.muli %y_idx_679, %flip_43 : tensor<128x2x2xi32, #blocked3> loc(#loc290) + %right_idx_685 = "tt.reduce"(%right_idx_684) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<128x2x2xi32, #blocked3>) -> tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc316) + %right_idx_686 = tt.expand_dims %right_idx_685 {axis = 1 : i32} : tensor<128x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<128x1x2xi32, #blocked3> loc(#loc292) + %right_idx_687 = tt.broadcast %right_idx_686 : tensor<128x1x2xi32, #blocked3> -> tensor<128x2x2xi32, #blocked3> loc(#loc293) + %left_idx_688 = tt.reshape %left_idx_683 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_689 = tt.reshape %right_idx_687 : tensor<128x2x2xi32, #blocked3> -> tensor<32x16xi32, #blocked> loc(#loc295) + 
%cond_690 = arith.cmpi slt, %ileft_677, %iright_678 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_691 = arith.cmpi eq, %ileft_677, %iright_678 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_692 = arith.cmpi sgt, %left_idx_688, %right_idx_689 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_693 = arith.andi %eq_691, %cond_692 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_694 = arith.ori %cond_690, %cond_693 : tensor<32x16xi1, #blocked> loc(#loc277) + %ret_695 = arith.xori %ileft_677, %iright_678 : tensor<32x16xi32, #blocked> loc(#loc280) + %ret_696 = arith.select %cond_694, %ret_695, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc281) + %ret_697 = arith.xori %ret_664, %ret_696 : tensor<32x16xi32, #blocked> loc(#loc282) + %new_idxs_698 = arith.xori %left_idx_688, %right_idx_689 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_699 = arith.select %cond_694, %new_idxs_698, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_700 = arith.xori %new_idxs_667, %new_idxs_699 : tensor<32x16xi32, #blocked> loc(#loc284) + %y_701 = tt.reshape %ret_697 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc263) + %ileft_702 = arith.muli %y_701, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc264) + %ileft_703 = "tt.reduce"(%ileft_702) <{axis = 1 : i32}> ({ + ^bb0(%ileft_741: i32 loc(callsite(#loc1 at #loc265)), %ileft_742: i32 loc(callsite(#loc1 at #loc265))): + %ileft_743 = arith.addi %ileft_741, %ileft_742 : i32 loc(#loc322) + tt.reduce.return %ileft_743 : i32 loc(#loc310) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc310) + %ileft_704 = tt.expand_dims %ileft_703 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc266) + %ileft_705 = tt.broadcast %ileft_704 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc267) + 
%iright_706 = arith.muli %y_701, %iright : tensor<256x2x1xi32, #blocked4> loc(#loc268) + %iright_707 = "tt.reduce"(%iright_706) <{axis = 1 : i32}> ({ + ^bb0(%iright_741: i32 loc(callsite(#loc1 at #loc269)), %iright_742: i32 loc(callsite(#loc1 at #loc269))): + %iright_743 = arith.addi %iright_741, %iright_742 : i32 loc(#loc323) + tt.reduce.return %iright_743 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc312) + %iright_708 = tt.expand_dims %iright_707 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc270) + %iright_709 = tt.broadcast %iright_708 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc271) + %ileft_710 = tt.reshape %ileft_705 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc272) + %iright_711 = tt.reshape %iright_709 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc273) + %y_idx_712 = tt.reshape %new_idxs_700 : tensor<32x16xi32, #blocked> -> tensor<256x2x1xi32, #blocked4> loc(#loc285) + %left_idx_713 = arith.muli %y_idx_712, %ileft : tensor<256x2x1xi32, #blocked4> loc(#loc286) + %left_idx_714 = "tt.reduce"(%left_idx_713) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_741: i32 loc(callsite(#loc1 at #loc287)), %left_idx_742: i32 loc(callsite(#loc1 at #loc287))): + %left_idx_743 = arith.addi %left_idx_741, %left_idx_742 : i32 loc(#loc324) + tt.reduce.return %left_idx_743 : i32 loc(#loc314) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc314) + %left_idx_715 = tt.expand_dims %left_idx_714 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc288) + %left_idx_716 = tt.broadcast %left_idx_715 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc289) + %right_idx_717 = arith.muli %y_idx_712, 
%iright : tensor<256x2x1xi32, #blocked4> loc(#loc290) + %right_idx_718 = "tt.reduce"(%right_idx_717) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_741: i32 loc(callsite(#loc1 at #loc291)), %right_idx_742: i32 loc(callsite(#loc1 at #loc291))): + %right_idx_743 = arith.addi %right_idx_741, %right_idx_742 : i32 loc(#loc325) + tt.reduce.return %right_idx_743 : i32 loc(#loc316) + }) : (tensor<256x2x1xi32, #blocked4>) -> tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc316) + %right_idx_719 = tt.expand_dims %right_idx_718 {axis = 1 : i32} : tensor<256x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<256x1x1xi32, #blocked4> loc(#loc292) + %right_idx_720 = tt.broadcast %right_idx_719 : tensor<256x1x1xi32, #blocked4> -> tensor<256x2x1xi32, #blocked4> loc(#loc293) + %left_idx_721 = tt.reshape %left_idx_716 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc294) + %right_idx_722 = tt.reshape %right_idx_720 : tensor<256x2x1xi32, #blocked4> -> tensor<32x16xi32, #blocked> loc(#loc295) + %cond_723 = arith.cmpi slt, %ileft_710, %iright_711 : tensor<32x16xi32, #blocked> loc(#loc274) + %eq_724 = arith.cmpi eq, %ileft_710, %iright_711 : tensor<32x16xi32, #blocked> loc(#loc275) + %cond_725 = arith.cmpi sgt, %left_idx_721, %right_idx_722 : tensor<32x16xi32, #blocked> loc(#loc296) + %cond_726 = arith.andi %eq_724, %cond_725 : tensor<32x16xi1, #blocked> loc(#loc276) + %cond_727 = arith.ori %cond_723, %cond_726 : tensor<32x16xi1, #blocked> loc(#loc277) + %new_idxs_728 = arith.xori %left_idx_721, %right_idx_722 : tensor<32x16xi32, #blocked> loc(#loc297) + %new_idxs_729 = arith.select %cond_727, %new_idxs_728, %cst_9 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc283) + %new_idxs_730 = arith.xori %new_idxs_700, %new_idxs_729 : tensor<32x16xi32, #blocked> loc(#loc284) + %tmp20 = arith.extui %tmp5 : tensor<32x16xi1, #blocked> to tensor<32x16xi64, #blocked> loc(#loc219) + %tmp23 = arith.select %tmp0_29, %tmp20, %cst_13 : 
tensor<32x16xi1, #blocked>, tensor<32x16xi64, #blocked> loc(#loc191) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_741: i64 loc(callsite(#loc1 at #loc192)), %tmp24_742: i64 loc(callsite(#loc1 at #loc192))): + %tmp24_743 = arith.addi %tmp24_741, %tmp24_742 : i64 loc(#loc298) + tt.reduce.return %tmp24_743 : i64 loc(#loc220) + }) : (tensor<32x16xi64, #blocked>) -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc220) + %tmp30 = ttg.convert_layout %tmp24 : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc193) + %tmp24_731 = tt.expand_dims %tmp30 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<32x1xi64, #blocked5> loc(#loc194) + %tmp24_732 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi64, #blocked> loc(#loc194) + %tmp25 = arith.extui %tmp14 : tensor<32x16xi1, #blocked> to tensor<32x16xi64, #blocked> loc(#loc222) + %tmp28 = arith.select %tmp0_29, %tmp25, %cst_13 : tensor<32x16xi1, #blocked>, tensor<32x16xi64, #blocked> loc(#loc196) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_741: i64 loc(callsite(#loc1 at #loc197)), %tmp29_742: i64 loc(callsite(#loc1 at #loc197))): + %tmp29_743 = arith.addi %tmp29_741, %tmp29_742 : i64 loc(#loc299) + tt.reduce.return %tmp29_743 : i64 loc(#loc223) + }) : (tensor<32x16xi64, #blocked>) -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc223) + %tmp31 = ttg.convert_layout %tmp29 : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc198) + %tmp29_733 = tt.expand_dims %tmp31 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<32x1xi64, #blocked5> loc(#loc199) + %tmp29_734 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<32xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> 
tensor<32x1xi64, #blocked> loc(#loc199) + %tmp30_735 = arith.trunci %tmp24_731 : tensor<32x1xi64, #blocked5> to tensor<32x1xi32, #blocked5> loc(#loc193) + %tmp30_736 = arith.trunci %tmp24_732 : tensor<32x1xi64, #blocked> to tensor<32x1xi32, #blocked> loc(#loc193) + %tmp31_737 = arith.trunci %tmp29_733 : tensor<32x1xi64, #blocked5> to tensor<32x1xi32, #blocked5> loc(#loc198) + %tmp31_738 = arith.trunci %tmp29_734 : tensor<32x1xi64, #blocked> to tensor<32x1xi32, #blocked> loc(#loc198) + %tmp34 = tt.broadcast %tmp30_736 : tensor<32x1xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc200) + %tmp34_739 = arith.cmpi slt, %tmp0_24, %tmp34 : tensor<32x16xi32, #blocked> loc(#loc200) + %tmp36 = arith.select %tmp34_739, %new_idxs_398, %cst_11 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc201) + %tmp38 = arith.addi %tmp36, %cst_10 : tensor<32x16xi32, #blocked> loc(#loc202) + %tmp39 = arith.cmpi slt, %tmp36, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc203) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc204) + %0 = arith.cmpi sge, %tmp40, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc85) + %1 = arith.cmpi slt, %tmp40, %cst_10 : tensor<32x16xi32, #blocked> loc(#loc86) + %2 = arith.andi %0, %1 : tensor<32x16xi1, #blocked> loc(#loc87) + %3 = arith.xori %xmask, %cst : tensor<32x1xi1, #blocked> loc(#loc88) + %4 = tt.broadcast %3 : tensor<32x1xi1, #blocked> -> tensor<32x16xi1, #blocked> loc(#loc89) + %5 = arith.ori %2, %4 : tensor<32x16xi1, #blocked> loc(#loc89) + tt.assert %5, "index out of bounds: 0 <= tmp40 < 17" : tensor<32x16xi1, #blocked> loc(#loc90) + %tmp45 = tt.broadcast %tmp31_738 : tensor<32x1xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc205) + %tmp45_740 = arith.cmpi slt, %tmp0_24, %tmp45 : tensor<32x16xi32, #blocked> loc(#loc205) + %tmp46 = arith.select %tmp45_740, %new_idxs_730, %cst_11 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc206) + %tmp47 = 
arith.addi %tmp46, %cst_10 : tensor<32x16xi32, #blocked> loc(#loc207) + %tmp48 = arith.cmpi slt, %tmp46, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc208) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<32x16xi1, #blocked>, tensor<32x16xi32, #blocked> loc(#loc209) + %6 = arith.cmpi sge, %tmp49, %cst_9 : tensor<32x16xi32, #blocked> loc(#loc96) + %7 = arith.cmpi slt, %tmp49, %cst_10 : tensor<32x16xi32, #blocked> loc(#loc97) + %8 = arith.andi %6, %7 : tensor<32x16xi1, #blocked> loc(#loc98) + %9 = arith.ori %8, %4 : tensor<32x16xi1, #blocked> loc(#loc99) + tt.assert %9, "index out of bounds: 0 <= tmp49 < 17" : tensor<32x16xi1, #blocked> loc(#loc100) + %10 = tt.splat %out_ptr4 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked5> loc(#loc101) + %11 = tt.addptr %10, %xindex_21 : tensor<32x1x!tt.ptr, #blocked5>, tensor<32x1xi32, #blocked5> loc(#loc101) + tt.store %11, %tmp30_735, %xmask_22 : tensor<32x1x!tt.ptr, #blocked5> loc(#loc102) + %12 = tt.splat %out_ptr5 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked5> loc(#loc103) + %13 = tt.addptr %12, %xindex_21 : tensor<32x1x!tt.ptr, #blocked5>, tensor<32x1xi32, #blocked5> loc(#loc103) + tt.store %13, %tmp31_737, %xmask_22 : tensor<32x1x!tt.ptr, #blocked5> loc(#loc104) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> loc(#loc105) + %15 = tt.addptr %14, %tmp0_26 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked> loc(#loc105) + tt.store %15, %new_idxs_398, %tmp0_29 : tensor<32x16x!tt.ptr, #blocked> loc(#loc106) + %16 = arith.muli %xindex_20, %cst_0 : tensor<32x1xi32, #blocked> loc(#loc107) + %17 = tt.broadcast %16 : tensor<32x1xi32, #blocked> -> tensor<32x16xi32, #blocked> loc(#loc108) + %18 = arith.addi %tmp40, %17 : tensor<32x16xi32, #blocked> loc(#loc108) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> loc(#loc109) + %20 = tt.addptr %19, %18 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked> loc(#loc109) + %21 = ttg.convert_layout %20 : tensor<32x16x!tt.ptr, 
#blocked> -> tensor<32x16x!tt.ptr, #blocked5> loc(#loc110) + tt.store %21, %cst_8, %tmp0_30 : tensor<32x16x!tt.ptr, #blocked5> loc(#loc110) + %22 = tt.splat %out_ptr8 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> loc(#loc111) + %23 = tt.addptr %22, %tmp0_26 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked> loc(#loc111) + tt.store %23, %new_idxs_730, %tmp0_29 : tensor<32x16x!tt.ptr, #blocked> loc(#loc112) + %24 = arith.addi %tmp49, %17 : tensor<32x16xi32, #blocked> loc(#loc113) + %25 = tt.splat %out_ptr9 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> loc(#loc114) + %26 = tt.addptr %25, %24 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked> loc(#loc114) + %27 = ttg.convert_layout %26 : tensor<32x16x!tt.ptr, #blocked> -> tensor<32x16x!tt.ptr, #blocked5> loc(#loc115) + tt.store %27, %cst_8, %tmp0_30 : tensor<32x16x!tt.ptr, #blocked5> loc(#loc115) + tt.return loc(#loc116) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":26:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":27:38) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:40) +#loc9 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:37) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:30) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:45) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":36:18) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":38:18) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":39:18) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":41:19) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":40:19) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":43:19) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":45:34) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc27 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc47 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc65 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":47:20) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":49:21) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":48:21) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":52:20) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":54:35) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":60:21) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:29) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":56:21) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":58:35) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":61:21) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:29) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":64:19) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":66:35) +#loc82 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":68:20) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":69:20) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":70:35) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:28) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:46) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:38) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:55) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:53) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:63) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":75:19) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":76:35) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":77:20) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":78:20) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":79:35) +#loc96 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:28) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:46) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:38) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:53) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:63) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:25) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:37) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:25) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:37) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:25) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:47) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:52) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:49) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:25) +#loc110 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:85) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:25) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:47) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:49) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:25) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:85) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:4) +#loc126 = loc("xoffset"(#loc2)) +#loc127 = loc("xoffset"(#loc3)) +#loc128 = loc("xindex"(#loc4)) +#loc129 = loc("xindex"(#loc5)) +#loc130 = loc("xmask"(#loc6)) +#loc131 = loc("r0_index"(#loc7)) +#loc132 = loc("tmp0"(#loc8)) +#loc133 = loc("tmp0"(#loc9)) +#loc134 = loc("tmp0"(#loc10)) +#loc135 = loc("tmp0"(#loc11)) +#loc136 = loc("tmp2"(#loc12)) +#loc137 = loc("tmp4"(#loc13)) +#loc138 = loc("tmp5"(#loc14)) +#loc139 = loc("tmp7"(#loc15)) +#loc140 = loc("tmp6"(#loc16)) +#loc141 = loc("tmp9"(#loc17)) +#loc142 = loc("tmp11"(#loc18)) +#loc143 = loc("flip"(#loc19)) +#loc145 = loc("flip"(#loc22)) +#loc146 = loc("flip"(#loc23)) +#loc147 = loc("y"(#loc24)) +#loc148 = loc("left_mask"(#loc26)) +#loc149 = loc("ileft"(#loc27)) +#loc151 = loc("ileft"(#loc31)) +#loc152 = loc("ileft"(#loc32)) +#loc153 = loc("iright"(#loc33)) +#loc155 = loc("iright"(#loc35)) +#loc156 = loc("iright"(#loc36)) +#loc157 = loc("ileft"(#loc37)) +#loc158 = loc("iright"(#loc38)) +#loc159 = loc("y_idx"(#loc39)) +#loc160 = loc("left_idx"(#loc40)) +#loc161 = 
loc("left_idx"(#loc41)) +#loc162 = loc("input"(#loc42)) +#loc164 = loc("left_idx"(#loc44)) +#loc165 = loc("left_idx"(#loc45)) +#loc166 = loc("right_idx"(#loc46)) +#loc167 = loc("right_idx"(#loc47)) +#loc169 = loc("right_idx"(#loc49)) +#loc170 = loc("right_idx"(#loc50)) +#loc171 = loc("left_idx"(#loc51)) +#loc172 = loc("right_idx"(#loc52)) +#loc173 = loc("cond"(#loc53)) +#loc174 = loc("eq"(#loc54)) +#loc175 = loc("cond"(#loc55)) +#loc176 = loc("cond"(#loc56)) +#loc177 = loc("cond"(#loc57)) +#loc178 = loc("cond"(#loc58)) +#loc179 = loc("cond"(#loc59)) +#loc180 = loc("ret"(#loc60)) +#loc181 = loc("ret"(#loc61)) +#loc182 = loc("ret"(#loc62)) +#loc183 = loc("new_idxs"(#loc63)) +#loc184 = loc("new_idxs"(#loc64)) +#loc185 = loc("new_idxs"(#loc65)) +#loc186 = loc("tmp14"(#loc66)) +#loc187 = loc("tmp16"(#loc67)) +#loc188 = loc("tmp15"(#loc68)) +#loc190 = loc("tmp20"(#loc70)) +#loc191 = loc("tmp23"(#loc71)) +#loc193 = loc("tmp30"(#loc73)) +#loc194 = loc("tmp24"(#loc74)) +#loc195 = loc("tmp25"(#loc75)) +#loc196 = loc("tmp28"(#loc76)) +#loc198 = loc("tmp31"(#loc78)) +#loc199 = loc("tmp29"(#loc79)) +#loc200 = loc("tmp34"(#loc80)) +#loc201 = loc("tmp36"(#loc81)) +#loc202 = loc("tmp38"(#loc82)) +#loc203 = loc("tmp39"(#loc83)) +#loc204 = loc("tmp40"(#loc84)) +#loc205 = loc("tmp45"(#loc91)) +#loc206 = loc("tmp46"(#loc92)) +#loc207 = loc("tmp47"(#loc93)) +#loc208 = loc("tmp48"(#loc94)) +#loc209 = loc("tmp49"(#loc95)) +#loc210 = loc(fused[#loc139, #loc140]) +#loc211 = loc(callsite(#loc143 at #loc144)) +#loc212 = loc(callsite(#loc145 at #loc144)) +#loc213 = loc(callsite(#loc146 at #loc144)) +#loc215 = loc("cond"(#loc173)) +#loc216 = loc("eq"(#loc174)) +#loc217 = loc(fused[#loc187, #loc188]) +#loc219 = loc(fused[#loc190, #loc139, #loc140]) +#loc220 = loc(callsite(#loc28 at #loc192)) +#loc222 = loc(fused[#loc195, #loc187, #loc188]) +#loc223 = loc(callsite(#loc28 at #loc197)) +#loc225 = loc(callsite(#loc147 at #loc214)) +#loc226 = loc(callsite(#loc148 at #loc214)) +#loc227 = 
loc(callsite(#loc149 at #loc214)) +#loc229 = loc(callsite(#loc151 at #loc214)) +#loc230 = loc(callsite(#loc152 at #loc214)) +#loc231 = loc(callsite(#loc153 at #loc214)) +#loc233 = loc(callsite(#loc155 at #loc214)) +#loc234 = loc(callsite(#loc156 at #loc214)) +#loc235 = loc(callsite(#loc157 at #loc214)) +#loc236 = loc(callsite(#loc158 at #loc214)) +#loc237 = loc(callsite(#loc159 at #loc214)) +#loc238 = loc(callsite(#loc160 at #loc214)) +#loc239 = loc(callsite(#loc161 at #loc214)) +#loc241 = loc(callsite(#loc164 at #loc214)) +#loc242 = loc(callsite(#loc165 at #loc214)) +#loc243 = loc(callsite(#loc166 at #loc214)) +#loc244 = loc(callsite(#loc167 at #loc214)) +#loc246 = loc(callsite(#loc169 at #loc214)) +#loc247 = loc(callsite(#loc170 at #loc214)) +#loc248 = loc(callsite(#loc171 at #loc214)) +#loc249 = loc(callsite(#loc172 at #loc214)) +#loc250 = loc(callsite(#loc215 at #loc214)) +#loc251 = loc(callsite(#loc216 at #loc214)) +#loc252 = loc(callsite(#loc175 at #loc214)) +#loc253 = loc(callsite(#loc176 at #loc214)) +#loc254 = loc(callsite(#loc177 at #loc214)) +#loc255 = loc(callsite(#loc178 at #loc214)) +#loc256 = loc(callsite(#loc179 at #loc214)) +#loc257 = loc(callsite(#loc180 at #loc214)) +#loc258 = loc(callsite(#loc181 at #loc214)) +#loc259 = loc(callsite(#loc182 at #loc214)) +#loc260 = loc(callsite(#loc183 at #loc214)) +#loc261 = loc(callsite(#loc184 at #loc214)) +#loc262 = loc(callsite(#loc185 at #loc214)) +#loc263 = loc(callsite(#loc147 at #loc218)) +#loc264 = loc(callsite(#loc149 at #loc218)) +#loc266 = loc(callsite(#loc151 at #loc218)) +#loc267 = loc(callsite(#loc152 at #loc218)) +#loc268 = loc(callsite(#loc153 at #loc218)) +#loc270 = loc(callsite(#loc155 at #loc218)) +#loc271 = loc(callsite(#loc156 at #loc218)) +#loc272 = loc(callsite(#loc157 at #loc218)) +#loc273 = loc(callsite(#loc158 at #loc218)) +#loc274 = loc(callsite(#loc215 at #loc218)) +#loc275 = loc(callsite(#loc216 at #loc218)) +#loc276 = loc(callsite(#loc176 at #loc218)) +#loc277 = 
loc(callsite(#loc177 at #loc218)) +#loc278 = loc(callsite(#loc178 at #loc218)) +#loc279 = loc(callsite(#loc179 at #loc218)) +#loc280 = loc(callsite(#loc180 at #loc218)) +#loc281 = loc(callsite(#loc181 at #loc218)) +#loc282 = loc(callsite(#loc182 at #loc218)) +#loc283 = loc(callsite(#loc184 at #loc218)) +#loc284 = loc(callsite(#loc185 at #loc218)) +#loc285 = loc(callsite(#loc159 at #loc218)) +#loc286 = loc(callsite(#loc161 at #loc218)) +#loc288 = loc(callsite(#loc164 at #loc218)) +#loc289 = loc(callsite(#loc165 at #loc218)) +#loc290 = loc(callsite(#loc167 at #loc218)) +#loc292 = loc(callsite(#loc169 at #loc218)) +#loc293 = loc(callsite(#loc170 at #loc218)) +#loc294 = loc(callsite(#loc171 at #loc218)) +#loc295 = loc(callsite(#loc172 at #loc218)) +#loc296 = loc(callsite(#loc175 at #loc218)) +#loc297 = loc(callsite(#loc183 at #loc218)) +#loc298 = loc(callsite(#loc30 at #loc220)) +#loc299 = loc(callsite(#loc30 at #loc223)) +#loc300 = loc(callsite(#loc28 at #loc228)) +#loc302 = loc(callsite(#loc28 at #loc232)) +#loc304 = loc(callsite(#loc162 at #loc240)) +#loc305 = loc(callsite(#loc28 at #loc240)) +#loc307 = loc(callsite(#loc162 at #loc245)) +#loc308 = loc(callsite(#loc28 at #loc245)) +#loc310 = loc(callsite(#loc28 at #loc265)) +#loc312 = loc(callsite(#loc28 at #loc269)) +#loc314 = loc(callsite(#loc28 at #loc287)) +#loc316 = loc(callsite(#loc28 at #loc291)) +#loc318 = loc(callsite(#loc30 at #loc300)) +#loc319 = loc(callsite(#loc30 at #loc302)) +#loc320 = loc(callsite(#loc30 at #loc305)) +#loc321 = loc(callsite(#loc30 at #loc308)) +#loc322 = loc(callsite(#loc30 at #loc310)) +#loc323 = loc(callsite(#loc30 at #loc312)) +#loc324 = loc(callsite(#loc30 at #loc314)) +#loc325 = loc(callsite(#loc30 at #loc316)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir 
b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir new file mode 100644 index 0000000000000000000000000000000000000000..a16dec0ca54676aaefabf104e29dd72ec51420f8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/EFSFIQ3KOQUXLGL22JESW5DLVY6LJWVBODRP6EOX4ISS2B62HCMA/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir @@ -0,0 +1,1451 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":18:0) +#loc1 = loc(unknown) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":46:71) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":51:71) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:26) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:26) 
+#loc120 = loc("in_ptr0"(#loc)) +#loc121 = loc("out_ptr4"(#loc)) +#loc122 = loc("out_ptr5"(#loc)) +#loc123 = loc("out_ptr6"(#loc)) +#loc124 = loc("out_ptr7"(#loc)) +#loc125 = loc("out_ptr8"(#loc)) +#loc126 = loc("out_ptr9"(#loc)) +#loc127 = loc("xnumel"(#loc)) +#loc128 = loc("r0_numel"(#loc)) +#loc149 = loc(callsite(#loc22 at #loc23)) +#loc156 = loc("ileft"(#loc32)) +#loc160 = loc("iright"(#loc37)) +#loc169 = loc("left_idx"(#loc46)) +#loc174 = loc("right_idx"(#loc51)) +#loc195 = loc(callsite(#loc22 at #loc72)) +#loc198 = loc("tmp24"(#loc75)) +#loc202 = loc("tmp29"(#loc79)) +#loc221 = loc(callsite(#loc28 at #loc149)) +#loc225 = loc(callsite(#loc28 at #loc195)) +#loc228 = loc(callsite(#loc1 at #loc198)) +#loc231 = loc(callsite(#loc1 at #loc202)) +#loc235 = loc(callsite(#loc156 at #loc221)) +#loc239 = loc(callsite(#loc160 at #loc221)) +#loc247 = loc(callsite(#loc169 at #loc221)) +#loc252 = loc(callsite(#loc174 at #loc221)) +#loc272 = loc(callsite(#loc156 at #loc225)) +#loc276 = loc(callsite(#loc160 at #loc225)) +#loc294 = loc(callsite(#loc169 at #loc225)) +#loc298 = loc(callsite(#loc174 at #loc225)) +#loc308 = loc(callsite(#loc1 at #loc235)) +#loc310 = loc(callsite(#loc1 at #loc239)) +#loc313 = loc(callsite(#loc1 at #loc247)) +#loc316 = loc(callsite(#loc1 at #loc252)) +#loc318 = loc(callsite(#loc1 at #loc272)) +#loc320 = loc(callsite(#loc1 at #loc276)) +#loc322 = loc(callsite(#loc1 at #loc294)) +#loc324 = loc(callsite(#loc1 at #loc298)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 
: i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<32x16xi32> loc(#loc1) + %cst_1 = arith.constant dense<17> : tensor<32x1xi32> loc(#loc1) + %cst_2 = arith.constant dense : tensor<32x1xi1> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<32x16xi32> loc(#loc1) + %cst_4 = arith.constant dense<17> : tensor<32x16xi32> loc(#loc1) + %cst_5 = arith.constant dense<16> : tensor<32x16xi32> loc(#loc1) + %cst_6 = arith.constant dense<16384> : tensor<32x16xi64> loc(#loc1) + %cst_7 = arith.constant dense<0> : tensor<32x16xi64> loc(#loc1) + %cst_8 = arith.constant dense<16> : tensor<32x1xi32> loc(#loc1) + %xmask = arith.constant dense<128> : tensor<32x1xi32> loc(#loc129) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc130) + %xoffset_9 = arith.muli %xoffset, %c32_i32 : i32 loc(#loc131) + %xindex = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> loc(#loc132) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> loc(#loc133) + %xindex_11 = tt.splat %xoffset_9 : i32 -> tensor<32x1xi32> loc(#loc134) + %xindex_12 = arith.addi %xindex_11, %xindex_10 : tensor<32x1xi32> loc(#loc134) + %xmask_13 = arith.cmpi slt, %xindex_12, %xmask : tensor<32x1xi32> loc(#loc129) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc135) + %r0_index_14 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc136) + %tmp0 = arith.muli %xindex_12, %cst_8 : tensor<32x1xi32> loc(#loc137) + %tmp0_15 = tt.broadcast %r0_index_14 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc138) + %tmp0_16 = tt.broadcast %tmp0 : 
tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc138) + %tmp0_17 = arith.addi %tmp0_15, %tmp0_16 : tensor<32x16xi32> loc(#loc138) + %tmp0_18 = tt.splat %in_ptr0 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc139) + %tmp0_19 = tt.addptr %tmp0_18, %tmp0_17 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc139) + %tmp0_20 = tt.broadcast %xmask_13 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc140) + %tmp0_21 = tt.load %tmp0_19, %tmp0_20, %cst_7 : tensor<32x16x!tt.ptr> loc(#loc140) + %tmp2 = arith.cmpi sgt, %tmp0_21, %cst_7 : tensor<32x16xi64> loc(#loc141) + %tmp4 = arith.cmpi slt, %tmp0_21, %cst_6 : tensor<32x16xi64> loc(#loc142) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<32x16xi1> loc(#loc143) + %tmp7 = arith.extui %tmp5 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc216) + %tmp9 = arith.trunci %r0_index_14 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc146) + %tmp11 = tt.broadcast %tmp9 : tensor<1x16xi16> -> tensor<32x16xi16> loc(#loc147) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc217) + %flip_22 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc218) + %flip_23 = tt.expand_dims %flip_22 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc218) + %flip_24 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc219) + %flip_25 = tt.reshape %flip_24 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc220) + %y = tt.reshape %tmp7 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc232) + %left_mask = arith.subi %cst, %flip_23 : tensor<1x2x1xi32> loc(#loc233) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc234) + %ileft_26 = arith.muli %y, %ileft : tensor<256x2x1xi32> loc(#loc234) + %ileft_27 = "tt.reduce"(%ileft_26) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : 
i32 loc(#loc307) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc307) + %ileft_28 = tt.expand_dims %ileft_27 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc236) + %ileft_29 = tt.broadcast %ileft_28 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc237) + %iright = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<256x2x1xi32> loc(#loc238) + %iright_30 = arith.muli %y, %iright : tensor<256x2x1xi32> loc(#loc238) + %iright_31 = "tt.reduce"(%iright_30) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc309) + %iright_32 = tt.expand_dims %iright_31 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc240) + %iright_33 = tt.broadcast %iright_32 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc241) + %ileft_34 = tt.reshape %ileft_29 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_35 = tt.reshape %iright_33 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx = tt.reshape %tmp11 : tensor<32x16xi16> -> tensor<256x2x1xi16> loc(#loc244) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc245) + %left_idx_36 = tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<256x2x1xi16> loc(#loc246) + %left_idx_37 = arith.muli %y_idx, %left_idx_36 : tensor<256x2x1xi16> loc(#loc246) + %input = arith.extsi %left_idx_37 : tensor<256x2x1xi16> to tensor<256x2x1xi32> loc(#loc311) + %left_idx_38 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> 
loc(#loc312) + %left_idx_39 = tt.expand_dims %left_idx_38 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc248) + %left_idx_40 = tt.broadcast %left_idx_39 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc249) + %right_idx = arith.trunci %flip_23 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc250) + %right_idx_41 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<256x2x1xi16> loc(#loc251) + %right_idx_42 = arith.muli %y_idx, %right_idx_41 : tensor<256x2x1xi16> loc(#loc251) + %input_43 = arith.extsi %right_idx_42 : tensor<256x2x1xi16> to tensor<256x2x1xi32> loc(#loc314) + %right_idx_44 = "tt.reduce"(%input_43) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc315) + %right_idx_45 = tt.expand_dims %right_idx_44 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc253) + %right_idx_46 = tt.broadcast %right_idx_45 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc254) + %left_idx_47 = tt.reshape %left_idx_40 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_48 = tt.reshape %right_idx_46 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc256) + %cond = arith.cmpi slt, %ileft_34, %iright_35 : tensor<32x16xi32> loc(#loc257) + %eq = arith.cmpi eq, %ileft_34, %iright_35 : tensor<32x16xi32> loc(#loc258) + %cond_49 = arith.cmpi sgt, %left_idx_47, %right_idx_48 : tensor<32x16xi32> loc(#loc259) + %cond_50 = arith.andi %eq, %cond_49 : tensor<32x16xi1> loc(#loc260) + %cond_51 = arith.ori %cond, %cond_50 : tensor<32x16xi1> loc(#loc261) + %cond_52 = arith.extui %cond_51 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_53 = arith.xori %cond_52, %flip_25 : tensor<32x16xi32> loc(#loc262) + %cond_54 = arith.cmpi ne, %cond_53, %cst_3 : 
tensor<32x16xi32> loc(#loc263) + %ret = arith.xori %ileft_34, %iright_35 : tensor<32x16xi32> loc(#loc264) + %ret_55 = arith.select %cond_54, %ret, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_56 = arith.xori %tmp7, %ret_55 : tensor<32x16xi32> loc(#loc266) + %new_idxs = arith.xori %left_idx_47, %right_idx_48 : tensor<32x16xi32> loc(#loc267) + %new_idxs_57 = arith.select %cond_54, %new_idxs, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_58 = arith.extsi %tmp9 : tensor<1x16xi16> to tensor<1x16xi32> loc(#loc269) + %new_idxs_59 = tt.broadcast %new_idxs_58 : tensor<1x16xi32> -> tensor<32x16xi32> loc(#loc269) + %new_idxs_60 = arith.xori %new_idxs_59, %new_idxs_57 : tensor<32x16xi32> loc(#loc269) + %flip_61 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc219) + %flip_62 = tt.reshape %flip_61 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc220) + %y_63 = tt.reshape %ret_56 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc232) + %ileft_64 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<128x2x2xi32> loc(#loc234) + %ileft_65 = arith.muli %y_63, %ileft_64 : tensor<128x2x2xi32> loc(#loc234) + %ileft_66 = "tt.reduce"(%ileft_65) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc307) + %ileft_67 = tt.expand_dims %ileft_66 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc236) + %ileft_68 = tt.broadcast %ileft_67 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc237) + %iright_69 = arith.muli %y_63, %flip_24 : tensor<128x2x2xi32> loc(#loc238) + %iright_70 = "tt.reduce"(%iright_69) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = 
arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc309) + %iright_71 = tt.expand_dims %iright_70 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc240) + %iright_72 = tt.broadcast %iright_71 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc241) + %ileft_73 = tt.reshape %ileft_68 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_74 = tt.reshape %iright_72 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_75 = tt.reshape %new_idxs_60 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc244) + %left_idx_76 = arith.muli %y_idx_75, %ileft_64 : tensor<128x2x2xi32> loc(#loc246) + %left_idx_77 = "tt.reduce"(%left_idx_76) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc312) + %left_idx_78 = tt.expand_dims %left_idx_77 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc248) + %left_idx_79 = tt.broadcast %left_idx_78 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc249) + %right_idx_80 = arith.muli %y_idx_75, %flip_24 : tensor<128x2x2xi32> loc(#loc251) + %right_idx_81 = "tt.reduce"(%right_idx_80) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc315) + %right_idx_82 = tt.expand_dims %right_idx_81 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc253) + %right_idx_83 = tt.broadcast %right_idx_82 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> 
loc(#loc254) + %left_idx_84 = tt.reshape %left_idx_79 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_85 = tt.reshape %right_idx_83 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_86 = arith.cmpi slt, %ileft_73, %iright_74 : tensor<32x16xi32> loc(#loc257) + %eq_87 = arith.cmpi eq, %ileft_73, %iright_74 : tensor<32x16xi32> loc(#loc258) + %cond_88 = arith.cmpi sgt, %left_idx_84, %right_idx_85 : tensor<32x16xi32> loc(#loc259) + %cond_89 = arith.andi %eq_87, %cond_88 : tensor<32x16xi1> loc(#loc260) + %cond_90 = arith.ori %cond_86, %cond_89 : tensor<32x16xi1> loc(#loc261) + %cond_91 = arith.extui %cond_90 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_92 = arith.xori %cond_91, %flip_62 : tensor<32x16xi32> loc(#loc262) + %cond_93 = arith.cmpi ne, %cond_92, %cst_3 : tensor<32x16xi32> loc(#loc263) + %ret_94 = arith.xori %ileft_73, %iright_74 : tensor<32x16xi32> loc(#loc264) + %ret_95 = arith.select %cond_93, %ret_94, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_96 = arith.xori %ret_56, %ret_95 : tensor<32x16xi32> loc(#loc266) + %new_idxs_97 = arith.xori %left_idx_84, %right_idx_85 : tensor<32x16xi32> loc(#loc267) + %new_idxs_98 = arith.select %cond_93, %new_idxs_97, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_99 = arith.xori %new_idxs_60, %new_idxs_98 : tensor<32x16xi32> loc(#loc269) + %y_100 = tt.reshape %ret_96 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc232) + %ileft_101 = arith.muli %y_100, %ileft : tensor<256x2x1xi32> loc(#loc234) + %ileft_102 = "tt.reduce"(%ileft_101) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc307) + %ileft_103 = tt.expand_dims %ileft_102 {axis = 1 : i32} : tensor<256x1xi32> -> 
tensor<256x1x1xi32> loc(#loc236) + %ileft_104 = tt.broadcast %ileft_103 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc237) + %iright_105 = arith.muli %y_100, %iright : tensor<256x2x1xi32> loc(#loc238) + %iright_106 = "tt.reduce"(%iright_105) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc309) + %iright_107 = tt.expand_dims %iright_106 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc240) + %iright_108 = tt.broadcast %iright_107 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc241) + %ileft_109 = tt.reshape %ileft_104 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_110 = tt.reshape %iright_108 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_111 = tt.reshape %new_idxs_99 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc244) + %left_idx_112 = arith.muli %y_idx_111, %ileft : tensor<256x2x1xi32> loc(#loc246) + %left_idx_113 = "tt.reduce"(%left_idx_112) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc312) + %left_idx_114 = tt.expand_dims %left_idx_113 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc248) + %left_idx_115 = tt.broadcast %left_idx_114 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc249) + %right_idx_116 = arith.muli %y_idx_111, %iright : tensor<256x2x1xi32> loc(#loc251) + %right_idx_117 = "tt.reduce"(%right_idx_116) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at 
#loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc315) + %right_idx_118 = tt.expand_dims %right_idx_117 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc253) + %right_idx_119 = tt.broadcast %right_idx_118 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc254) + %left_idx_120 = tt.reshape %left_idx_115 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_121 = tt.reshape %right_idx_119 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_122 = arith.cmpi slt, %ileft_109, %iright_110 : tensor<32x16xi32> loc(#loc257) + %eq_123 = arith.cmpi eq, %ileft_109, %iright_110 : tensor<32x16xi32> loc(#loc258) + %cond_124 = arith.cmpi sgt, %left_idx_120, %right_idx_121 : tensor<32x16xi32> loc(#loc259) + %cond_125 = arith.andi %eq_123, %cond_124 : tensor<32x16xi1> loc(#loc260) + %cond_126 = arith.ori %cond_122, %cond_125 : tensor<32x16xi1> loc(#loc261) + %cond_127 = arith.extui %cond_126 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_128 = arith.xori %cond_127, %flip_62 : tensor<32x16xi32> loc(#loc262) + %cond_129 = arith.cmpi ne, %cond_128, %cst_3 : tensor<32x16xi32> loc(#loc263) + %ret_130 = arith.xori %ileft_109, %iright_110 : tensor<32x16xi32> loc(#loc264) + %ret_131 = arith.select %cond_129, %ret_130, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_132 = arith.xori %ret_96, %ret_131 : tensor<32x16xi32> loc(#loc266) + %new_idxs_133 = arith.xori %left_idx_120, %right_idx_121 : tensor<32x16xi32> loc(#loc267) + %new_idxs_134 = arith.select %cond_129, %new_idxs_133, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_135 = arith.xori %new_idxs_99, %new_idxs_134 : tensor<32x16xi32> loc(#loc269) + %flip_136 = tt.broadcast %flip_23 : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc219) + %flip_137 = tt.reshape %flip_136 : 
tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc220) + %y_138 = tt.reshape %ret_132 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc232) + %ileft_139 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<64x2x4xi32> loc(#loc234) + %ileft_140 = arith.muli %y_138, %ileft_139 : tensor<64x2x4xi32> loc(#loc234) + %ileft_141 = "tt.reduce"(%ileft_140) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc307) + %ileft_142 = tt.expand_dims %ileft_141 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc236) + %ileft_143 = tt.broadcast %ileft_142 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc237) + %iright_144 = arith.muli %y_138, %flip_61 : tensor<64x2x4xi32> loc(#loc238) + %iright_145 = "tt.reduce"(%iright_144) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc309) + %iright_146 = tt.expand_dims %iright_145 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc240) + %iright_147 = tt.broadcast %iright_146 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc241) + %ileft_148 = tt.reshape %ileft_143 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_149 = tt.reshape %iright_147 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_150 = tt.reshape %new_idxs_135 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc244) + %left_idx_151 = arith.muli %y_idx_150, %ileft_139 : tensor<64x2x4xi32> loc(#loc246) + %left_idx_152 = "tt.reduce"(%left_idx_151) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), 
%left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc312) + %left_idx_153 = tt.expand_dims %left_idx_152 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc248) + %left_idx_154 = tt.broadcast %left_idx_153 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc249) + %right_idx_155 = arith.muli %y_idx_150, %flip_61 : tensor<64x2x4xi32> loc(#loc251) + %right_idx_156 = "tt.reduce"(%right_idx_155) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc315) + %right_idx_157 = tt.expand_dims %right_idx_156 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc253) + %right_idx_158 = tt.broadcast %right_idx_157 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc254) + %left_idx_159 = tt.reshape %left_idx_154 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_160 = tt.reshape %right_idx_158 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_161 = arith.cmpi slt, %ileft_148, %iright_149 : tensor<32x16xi32> loc(#loc257) + %eq_162 = arith.cmpi eq, %ileft_148, %iright_149 : tensor<32x16xi32> loc(#loc258) + %cond_163 = arith.cmpi sgt, %left_idx_159, %right_idx_160 : tensor<32x16xi32> loc(#loc259) + %cond_164 = arith.andi %eq_162, %cond_163 : tensor<32x16xi1> loc(#loc260) + %cond_165 = arith.ori %cond_161, %cond_164 : tensor<32x16xi1> loc(#loc261) + %cond_166 = arith.extui %cond_165 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_167 = arith.xori %cond_166, %flip_137 : tensor<32x16xi32> loc(#loc262) + %cond_168 = arith.cmpi ne, %cond_167, %cst_3 : tensor<32x16xi32> 
loc(#loc263) + %ret_169 = arith.xori %ileft_148, %iright_149 : tensor<32x16xi32> loc(#loc264) + %ret_170 = arith.select %cond_168, %ret_169, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_171 = arith.xori %ret_132, %ret_170 : tensor<32x16xi32> loc(#loc266) + %new_idxs_172 = arith.xori %left_idx_159, %right_idx_160 : tensor<32x16xi32> loc(#loc267) + %new_idxs_173 = arith.select %cond_168, %new_idxs_172, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_174 = arith.xori %new_idxs_135, %new_idxs_173 : tensor<32x16xi32> loc(#loc269) + %y_175 = tt.reshape %ret_171 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc232) + %ileft_176 = arith.muli %y_175, %ileft_64 : tensor<128x2x2xi32> loc(#loc234) + %ileft_177 = "tt.reduce"(%ileft_176) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc307) + %ileft_178 = tt.expand_dims %ileft_177 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc236) + %ileft_179 = tt.broadcast %ileft_178 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc237) + %iright_180 = arith.muli %y_175, %flip_24 : tensor<128x2x2xi32> loc(#loc238) + %iright_181 = "tt.reduce"(%iright_180) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc309) + %iright_182 = tt.expand_dims %iright_181 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc240) + %iright_183 = tt.broadcast %iright_182 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc241) + %ileft_184 = tt.reshape %ileft_179 : 
tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_185 = tt.reshape %iright_183 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_186 = tt.reshape %new_idxs_174 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc244) + %left_idx_187 = arith.muli %y_idx_186, %ileft_64 : tensor<128x2x2xi32> loc(#loc246) + %left_idx_188 = "tt.reduce"(%left_idx_187) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc312) + %left_idx_189 = tt.expand_dims %left_idx_188 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc248) + %left_idx_190 = tt.broadcast %left_idx_189 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc249) + %right_idx_191 = arith.muli %y_idx_186, %flip_24 : tensor<128x2x2xi32> loc(#loc251) + %right_idx_192 = "tt.reduce"(%right_idx_191) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc315) + %right_idx_193 = tt.expand_dims %right_idx_192 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc253) + %right_idx_194 = tt.broadcast %right_idx_193 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc254) + %left_idx_195 = tt.reshape %left_idx_190 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_196 = tt.reshape %right_idx_194 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_197 = arith.cmpi slt, %ileft_184, %iright_185 : tensor<32x16xi32> loc(#loc257) + %eq_198 = arith.cmpi eq, %ileft_184, %iright_185 : tensor<32x16xi32> loc(#loc258) + %cond_199 
= arith.cmpi sgt, %left_idx_195, %right_idx_196 : tensor<32x16xi32> loc(#loc259) + %cond_200 = arith.andi %eq_198, %cond_199 : tensor<32x16xi1> loc(#loc260) + %cond_201 = arith.ori %cond_197, %cond_200 : tensor<32x16xi1> loc(#loc261) + %cond_202 = arith.extui %cond_201 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_203 = arith.xori %cond_202, %flip_137 : tensor<32x16xi32> loc(#loc262) + %cond_204 = arith.cmpi ne, %cond_203, %cst_3 : tensor<32x16xi32> loc(#loc263) + %ret_205 = arith.xori %ileft_184, %iright_185 : tensor<32x16xi32> loc(#loc264) + %ret_206 = arith.select %cond_204, %ret_205, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_207 = arith.xori %ret_171, %ret_206 : tensor<32x16xi32> loc(#loc266) + %new_idxs_208 = arith.xori %left_idx_195, %right_idx_196 : tensor<32x16xi32> loc(#loc267) + %new_idxs_209 = arith.select %cond_204, %new_idxs_208, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_210 = arith.xori %new_idxs_174, %new_idxs_209 : tensor<32x16xi32> loc(#loc269) + %y_211 = tt.reshape %ret_207 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc232) + %ileft_212 = arith.muli %y_211, %ileft : tensor<256x2x1xi32> loc(#loc234) + %ileft_213 = "tt.reduce"(%ileft_212) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc307) + %ileft_214 = tt.expand_dims %ileft_213 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc236) + %ileft_215 = tt.broadcast %ileft_214 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc237) + %iright_216 = arith.muli %y_211, %iright : tensor<256x2x1xi32> loc(#loc238) + %iright_217 = "tt.reduce"(%iright_216) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at 
#loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc309) + %iright_218 = tt.expand_dims %iright_217 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc240) + %iright_219 = tt.broadcast %iright_218 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc241) + %ileft_220 = tt.reshape %ileft_215 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_221 = tt.reshape %iright_219 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_222 = tt.reshape %new_idxs_210 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc244) + %left_idx_223 = arith.muli %y_idx_222, %ileft : tensor<256x2x1xi32> loc(#loc246) + %left_idx_224 = "tt.reduce"(%left_idx_223) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc312) + %left_idx_225 = tt.expand_dims %left_idx_224 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc248) + %left_idx_226 = tt.broadcast %left_idx_225 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc249) + %right_idx_227 = arith.muli %y_idx_222, %iright : tensor<256x2x1xi32> loc(#loc251) + %right_idx_228 = "tt.reduce"(%right_idx_227) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc315) + %right_idx_229 = tt.expand_dims %right_idx_228 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc253) + %right_idx_230 = tt.broadcast 
%right_idx_229 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc254) + %left_idx_231 = tt.reshape %left_idx_226 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_232 = tt.reshape %right_idx_230 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_233 = arith.cmpi slt, %ileft_220, %iright_221 : tensor<32x16xi32> loc(#loc257) + %eq_234 = arith.cmpi eq, %ileft_220, %iright_221 : tensor<32x16xi32> loc(#loc258) + %cond_235 = arith.cmpi sgt, %left_idx_231, %right_idx_232 : tensor<32x16xi32> loc(#loc259) + %cond_236 = arith.andi %eq_234, %cond_235 : tensor<32x16xi1> loc(#loc260) + %cond_237 = arith.ori %cond_233, %cond_236 : tensor<32x16xi1> loc(#loc261) + %cond_238 = arith.extui %cond_237 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc262) + %cond_239 = arith.xori %cond_238, %flip_137 : tensor<32x16xi32> loc(#loc262) + %cond_240 = arith.cmpi ne, %cond_239, %cst_3 : tensor<32x16xi32> loc(#loc263) + %ret_241 = arith.xori %ileft_220, %iright_221 : tensor<32x16xi32> loc(#loc264) + %ret_242 = arith.select %cond_240, %ret_241, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_243 = arith.xori %ret_207, %ret_242 : tensor<32x16xi32> loc(#loc266) + %new_idxs_244 = arith.xori %left_idx_231, %right_idx_232 : tensor<32x16xi32> loc(#loc267) + %new_idxs_245 = arith.select %cond_240, %new_idxs_244, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_246 = arith.xori %new_idxs_210, %new_idxs_245 : tensor<32x16xi32> loc(#loc269) + %y_247 = tt.reshape %ret_243 : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc232) + %ileft_248 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<32x2x8xi32> loc(#loc234) + %ileft_249 = arith.muli %y_247, %ileft_248 : tensor<32x2x8xi32> loc(#loc234) + %ileft_250 = "tt.reduce"(%ileft_249) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + 
tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc307) + %ileft_251 = tt.expand_dims %ileft_250 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc236) + %ileft_252 = tt.broadcast %ileft_251 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc237) + %iright_253 = arith.muli %y_247, %flip_136 : tensor<32x2x8xi32> loc(#loc238) + %iright_254 = "tt.reduce"(%iright_253) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc309) + %iright_255 = tt.expand_dims %iright_254 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc240) + %iright_256 = tt.broadcast %iright_255 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc241) + %ileft_257 = tt.reshape %ileft_252 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_258 = tt.reshape %iright_256 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_259 = tt.reshape %new_idxs_246 : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc244) + %left_idx_260 = arith.muli %y_idx_259, %ileft_248 : tensor<32x2x8xi32> loc(#loc246) + %left_idx_261 = "tt.reduce"(%left_idx_260) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc312) + %left_idx_262 = tt.expand_dims %left_idx_261 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc248) + %left_idx_263 = tt.broadcast %left_idx_262 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc249) + %right_idx_264 = arith.muli %y_idx_259, %flip_136 : tensor<32x2x8xi32> loc(#loc251) 
+ %right_idx_265 = "tt.reduce"(%right_idx_264) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc315) + %right_idx_266 = tt.expand_dims %right_idx_265 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc253) + %right_idx_267 = tt.broadcast %right_idx_266 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc254) + %left_idx_268 = tt.reshape %left_idx_263 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_269 = tt.reshape %right_idx_267 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_270 = arith.cmpi slt, %ileft_257, %iright_258 : tensor<32x16xi32> loc(#loc257) + %eq_271 = arith.cmpi eq, %ileft_257, %iright_258 : tensor<32x16xi32> loc(#loc258) + %cond_272 = arith.cmpi sgt, %left_idx_268, %right_idx_269 : tensor<32x16xi32> loc(#loc259) + %cond_273 = arith.andi %eq_271, %cond_272 : tensor<32x16xi1> loc(#loc260) + %cond_274 = arith.ori %cond_270, %cond_273 : tensor<32x16xi1> loc(#loc261) + %ret_275 = arith.xori %ileft_257, %iright_258 : tensor<32x16xi32> loc(#loc264) + %ret_276 = arith.select %cond_274, %ret_275, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_277 = arith.xori %ret_243, %ret_276 : tensor<32x16xi32> loc(#loc266) + %new_idxs_278 = arith.xori %left_idx_268, %right_idx_269 : tensor<32x16xi32> loc(#loc267) + %new_idxs_279 = arith.select %cond_274, %new_idxs_278, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_280 = arith.xori %new_idxs_246, %new_idxs_279 : tensor<32x16xi32> loc(#loc269) + %y_281 = tt.reshape %ret_277 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc232) + %ileft_282 = arith.muli %y_281, %ileft_139 : tensor<64x2x4xi32> loc(#loc234) + %ileft_283 = "tt.reduce"(%ileft_282) <{axis = 1 : i32}> 
({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc307) + %ileft_284 = tt.expand_dims %ileft_283 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc236) + %ileft_285 = tt.broadcast %ileft_284 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc237) + %iright_286 = arith.muli %y_281, %flip_61 : tensor<64x2x4xi32> loc(#loc238) + %iright_287 = "tt.reduce"(%iright_286) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc309) + %iright_288 = tt.expand_dims %iright_287 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc240) + %iright_289 = tt.broadcast %iright_288 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc241) + %ileft_290 = tt.reshape %ileft_285 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_291 = tt.reshape %iright_289 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_292 = tt.reshape %new_idxs_280 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc244) + %left_idx_293 = arith.muli %y_idx_292, %ileft_139 : tensor<64x2x4xi32> loc(#loc246) + %left_idx_294 = "tt.reduce"(%left_idx_293) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc312) + %left_idx_295 = tt.expand_dims %left_idx_294 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc248) + 
%left_idx_296 = tt.broadcast %left_idx_295 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc249) + %right_idx_297 = arith.muli %y_idx_292, %flip_61 : tensor<64x2x4xi32> loc(#loc251) + %right_idx_298 = "tt.reduce"(%right_idx_297) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc315) + %right_idx_299 = tt.expand_dims %right_idx_298 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc253) + %right_idx_300 = tt.broadcast %right_idx_299 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc254) + %left_idx_301 = tt.reshape %left_idx_296 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_302 = tt.reshape %right_idx_300 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_303 = arith.cmpi slt, %ileft_290, %iright_291 : tensor<32x16xi32> loc(#loc257) + %eq_304 = arith.cmpi eq, %ileft_290, %iright_291 : tensor<32x16xi32> loc(#loc258) + %cond_305 = arith.cmpi sgt, %left_idx_301, %right_idx_302 : tensor<32x16xi32> loc(#loc259) + %cond_306 = arith.andi %eq_304, %cond_305 : tensor<32x16xi1> loc(#loc260) + %cond_307 = arith.ori %cond_303, %cond_306 : tensor<32x16xi1> loc(#loc261) + %ret_308 = arith.xori %ileft_290, %iright_291 : tensor<32x16xi32> loc(#loc264) + %ret_309 = arith.select %cond_307, %ret_308, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_310 = arith.xori %ret_277, %ret_309 : tensor<32x16xi32> loc(#loc266) + %new_idxs_311 = arith.xori %left_idx_301, %right_idx_302 : tensor<32x16xi32> loc(#loc267) + %new_idxs_312 = arith.select %cond_307, %new_idxs_311, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_313 = arith.xori %new_idxs_280, %new_idxs_312 : tensor<32x16xi32> loc(#loc269) + %y_314 = tt.reshape %ret_310 : 
tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc232) + %ileft_315 = arith.muli %y_314, %ileft_64 : tensor<128x2x2xi32> loc(#loc234) + %ileft_316 = "tt.reduce"(%ileft_315) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc307) + %ileft_317 = tt.expand_dims %ileft_316 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc236) + %ileft_318 = tt.broadcast %ileft_317 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc237) + %iright_319 = arith.muli %y_314, %flip_24 : tensor<128x2x2xi32> loc(#loc238) + %iright_320 = "tt.reduce"(%iright_319) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc309) + %iright_321 = tt.expand_dims %iright_320 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc240) + %iright_322 = tt.broadcast %iright_321 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc241) + %ileft_323 = tt.reshape %ileft_318 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_324 = tt.reshape %iright_322 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_325 = tt.reshape %new_idxs_313 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc244) + %left_idx_326 = arith.muli %y_idx_325, %ileft_64 : tensor<128x2x2xi32> loc(#loc246) + %left_idx_327 = "tt.reduce"(%left_idx_326) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 
: i32 loc(#loc312) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc312) + %left_idx_328 = tt.expand_dims %left_idx_327 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc248) + %left_idx_329 = tt.broadcast %left_idx_328 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc249) + %right_idx_330 = arith.muli %y_idx_325, %flip_24 : tensor<128x2x2xi32> loc(#loc251) + %right_idx_331 = "tt.reduce"(%right_idx_330) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc315) + %right_idx_332 = tt.expand_dims %right_idx_331 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc253) + %right_idx_333 = tt.broadcast %right_idx_332 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc254) + %left_idx_334 = tt.reshape %left_idx_329 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_335 = tt.reshape %right_idx_333 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_336 = arith.cmpi slt, %ileft_323, %iright_324 : tensor<32x16xi32> loc(#loc257) + %eq_337 = arith.cmpi eq, %ileft_323, %iright_324 : tensor<32x16xi32> loc(#loc258) + %cond_338 = arith.cmpi sgt, %left_idx_334, %right_idx_335 : tensor<32x16xi32> loc(#loc259) + %cond_339 = arith.andi %eq_337, %cond_338 : tensor<32x16xi1> loc(#loc260) + %cond_340 = arith.ori %cond_336, %cond_339 : tensor<32x16xi1> loc(#loc261) + %ret_341 = arith.xori %ileft_323, %iright_324 : tensor<32x16xi32> loc(#loc264) + %ret_342 = arith.select %cond_340, %ret_341, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc265) + %ret_343 = arith.xori %ret_310, %ret_342 : tensor<32x16xi32> loc(#loc266) + %new_idxs_344 = arith.xori %left_idx_334, %right_idx_335 : tensor<32x16xi32> loc(#loc267) + %new_idxs_345 = arith.select 
%cond_340, %new_idxs_344, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc268) + %new_idxs_346 = arith.xori %new_idxs_313, %new_idxs_345 : tensor<32x16xi32> loc(#loc269) + %y_347 = tt.reshape %ret_343 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc232) + %ileft_348 = arith.muli %y_347, %ileft : tensor<256x2x1xi32> loc(#loc234) + %ileft_349 = "tt.reduce"(%ileft_348) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc235)), %ileft_714: i32 loc(callsite(#loc1 at #loc235))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc325) + tt.reduce.return %ileft_715 : i32 loc(#loc307) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc307) + %ileft_350 = tt.expand_dims %ileft_349 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc236) + %ileft_351 = tt.broadcast %ileft_350 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc237) + %iright_352 = arith.muli %y_347, %iright : tensor<256x2x1xi32> loc(#loc238) + %iright_353 = "tt.reduce"(%iright_352) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc239)), %iright_714: i32 loc(callsite(#loc1 at #loc239))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc326) + tt.reduce.return %iright_715 : i32 loc(#loc309) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc309) + %iright_354 = tt.expand_dims %iright_353 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc240) + %iright_355 = tt.broadcast %iright_354 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc241) + %ileft_356 = tt.reshape %ileft_351 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc242) + %iright_357 = tt.reshape %iright_355 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc243) + %y_idx_358 = tt.reshape %new_idxs_346 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc244) + %left_idx_359 = arith.muli %y_idx_358, %ileft : tensor<256x2x1xi32> loc(#loc246) + %left_idx_360 = "tt.reduce"(%left_idx_359) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: 
i32 loc(callsite(#loc1 at #loc247)), %left_idx_714: i32 loc(callsite(#loc1 at #loc247))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc327) + tt.reduce.return %left_idx_715 : i32 loc(#loc312) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc312) + %left_idx_361 = tt.expand_dims %left_idx_360 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc248) + %left_idx_362 = tt.broadcast %left_idx_361 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc249) + %right_idx_363 = arith.muli %y_idx_358, %iright : tensor<256x2x1xi32> loc(#loc251) + %right_idx_364 = "tt.reduce"(%right_idx_363) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc252)), %right_idx_714: i32 loc(callsite(#loc1 at #loc252))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc328) + tt.reduce.return %right_idx_715 : i32 loc(#loc315) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc315) + %right_idx_365 = tt.expand_dims %right_idx_364 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc253) + %right_idx_366 = tt.broadcast %right_idx_365 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc254) + %left_idx_367 = tt.reshape %left_idx_362 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc255) + %right_idx_368 = tt.reshape %right_idx_366 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc256) + %cond_369 = arith.cmpi slt, %ileft_356, %iright_357 : tensor<32x16xi32> loc(#loc257) + %eq_370 = arith.cmpi eq, %ileft_356, %iright_357 : tensor<32x16xi32> loc(#loc258) + %cond_371 = arith.cmpi sgt, %left_idx_367, %right_idx_368 : tensor<32x16xi32> loc(#loc259) + %cond_372 = arith.andi %eq_370, %cond_371 : tensor<32x16xi1> loc(#loc260) + %cond_373 = arith.ori %cond_369, %cond_372 : tensor<32x16xi1> loc(#loc261) + %new_idxs_374 = arith.xori %left_idx_367, %right_idx_368 : tensor<32x16xi32> loc(#loc267) + %new_idxs_375 = arith.select %cond_373, %new_idxs_374, %cst_3 : tensor<32x16xi1>, 
tensor<32x16xi32> loc(#loc268) + %new_idxs_376 = arith.xori %new_idxs_346, %new_idxs_375 : tensor<32x16xi32> loc(#loc269) + %tmp14 = arith.cmpi eq, %tmp0_21, %cst_6 : tensor<32x16xi64> loc(#loc192) + %tmp16 = arith.extui %tmp14 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc224) + %y_377 = tt.reshape %tmp16 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc270) + %ileft_378 = arith.muli %y_377, %ileft : tensor<256x2x1xi32> loc(#loc271) + %ileft_379 = "tt.reduce"(%ileft_378) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc317) + %ileft_380 = tt.expand_dims %ileft_379 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc273) + %ileft_381 = tt.broadcast %ileft_380 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc274) + %iright_382 = arith.muli %y_377, %iright : tensor<256x2x1xi32> loc(#loc275) + %iright_383 = "tt.reduce"(%iright_382) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc319) + %iright_384 = tt.expand_dims %iright_383 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc277) + %iright_385 = tt.broadcast %iright_384 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc278) + %ileft_386 = tt.reshape %ileft_381 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_387 = tt.reshape %iright_385 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %cond_388 = arith.cmpi slt, %ileft_386, %iright_387 : tensor<32x16xi32> loc(#loc281) + %eq_389 = arith.cmpi eq, %ileft_386, %iright_387 : tensor<32x16xi32> 
loc(#loc282) + %cond_390 = arith.andi %eq_389, %cond_49 : tensor<32x16xi1> loc(#loc283) + %cond_391 = arith.ori %cond_388, %cond_390 : tensor<32x16xi1> loc(#loc284) + %cond_392 = arith.extui %cond_391 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + %cond_393 = arith.xori %cond_392, %flip_25 : tensor<32x16xi32> loc(#loc285) + %cond_394 = arith.cmpi ne, %cond_393, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_395 = arith.xori %ileft_386, %iright_387 : tensor<32x16xi32> loc(#loc287) + %ret_396 = arith.select %cond_394, %ret_395, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_397 = arith.xori %tmp16, %ret_396 : tensor<32x16xi32> loc(#loc289) + %new_idxs_398 = arith.select %cond_394, %new_idxs, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_399 = arith.xori %new_idxs_59, %new_idxs_398 : tensor<32x16xi32> loc(#loc291) + %y_400 = tt.reshape %ret_397 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc270) + %ileft_401 = arith.muli %y_400, %ileft_64 : tensor<128x2x2xi32> loc(#loc271) + %ileft_402 = "tt.reduce"(%ileft_401) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc317) + %ileft_403 = tt.expand_dims %ileft_402 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc273) + %ileft_404 = tt.broadcast %ileft_403 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc274) + %iright_405 = arith.muli %y_400, %flip_24 : tensor<128x2x2xi32> loc(#loc275) + %iright_406 = "tt.reduce"(%iright_405) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<128x2x2xi32>) -> 
tensor<128x2xi32> loc(#loc319) + %iright_407 = tt.expand_dims %iright_406 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc277) + %iright_408 = tt.broadcast %iright_407 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc278) + %ileft_409 = tt.reshape %ileft_404 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_410 = tt.reshape %iright_408 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_411 = tt.reshape %new_idxs_399 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc292) + %left_idx_412 = arith.muli %y_idx_411, %ileft_64 : tensor<128x2x2xi32> loc(#loc293) + %left_idx_413 = "tt.reduce"(%left_idx_412) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc321) + %left_idx_414 = tt.expand_dims %left_idx_413 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc295) + %left_idx_415 = tt.broadcast %left_idx_414 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc296) + %right_idx_416 = arith.muli %y_idx_411, %flip_24 : tensor<128x2x2xi32> loc(#loc297) + %right_idx_417 = "tt.reduce"(%right_idx_416) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc323) + %right_idx_418 = tt.expand_dims %right_idx_417 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc299) + %right_idx_419 = tt.broadcast %right_idx_418 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc300) + %left_idx_420 = tt.reshape %left_idx_415 : tensor<128x2x2xi32> -> tensor<32x16xi32> 
loc(#loc301) + %right_idx_421 = tt.reshape %right_idx_419 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_422 = arith.cmpi slt, %ileft_409, %iright_410 : tensor<32x16xi32> loc(#loc281) + %eq_423 = arith.cmpi eq, %ileft_409, %iright_410 : tensor<32x16xi32> loc(#loc282) + %cond_424 = arith.cmpi sgt, %left_idx_420, %right_idx_421 : tensor<32x16xi32> loc(#loc303) + %cond_425 = arith.andi %eq_423, %cond_424 : tensor<32x16xi1> loc(#loc283) + %cond_426 = arith.ori %cond_422, %cond_425 : tensor<32x16xi1> loc(#loc284) + %cond_427 = arith.extui %cond_426 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + %cond_428 = arith.xori %cond_427, %flip_62 : tensor<32x16xi32> loc(#loc285) + %cond_429 = arith.cmpi ne, %cond_428, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_430 = arith.xori %ileft_409, %iright_410 : tensor<32x16xi32> loc(#loc287) + %ret_431 = arith.select %cond_429, %ret_430, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_432 = arith.xori %ret_397, %ret_431 : tensor<32x16xi32> loc(#loc289) + %new_idxs_433 = arith.xori %left_idx_420, %right_idx_421 : tensor<32x16xi32> loc(#loc304) + %new_idxs_434 = arith.select %cond_429, %new_idxs_433, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_435 = arith.xori %new_idxs_399, %new_idxs_434 : tensor<32x16xi32> loc(#loc291) + %y_436 = tt.reshape %ret_432 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc270) + %ileft_437 = arith.muli %y_436, %ileft : tensor<256x2x1xi32> loc(#loc271) + %ileft_438 = "tt.reduce"(%ileft_437) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc317) + %ileft_439 = tt.expand_dims %ileft_438 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc273) + %ileft_440 = tt.broadcast 
%ileft_439 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc274) + %iright_441 = arith.muli %y_436, %iright : tensor<256x2x1xi32> loc(#loc275) + %iright_442 = "tt.reduce"(%iright_441) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc319) + %iright_443 = tt.expand_dims %iright_442 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc277) + %iright_444 = tt.broadcast %iright_443 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc278) + %ileft_445 = tt.reshape %ileft_440 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_446 = tt.reshape %iright_444 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_447 = tt.reshape %new_idxs_435 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_448 = arith.muli %y_idx_447, %ileft : tensor<256x2x1xi32> loc(#loc293) + %left_idx_449 = "tt.reduce"(%left_idx_448) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc321) + %left_idx_450 = tt.expand_dims %left_idx_449 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc295) + %left_idx_451 = tt.broadcast %left_idx_450 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc296) + %right_idx_452 = arith.muli %y_idx_447, %iright : tensor<256x2x1xi32> loc(#loc297) + %right_idx_453 = "tt.reduce"(%right_idx_452) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, 
%right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc323) + %right_idx_454 = tt.expand_dims %right_idx_453 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc299) + %right_idx_455 = tt.broadcast %right_idx_454 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc300) + %left_idx_456 = tt.reshape %left_idx_451 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_457 = tt.reshape %right_idx_455 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_458 = arith.cmpi slt, %ileft_445, %iright_446 : tensor<32x16xi32> loc(#loc281) + %eq_459 = arith.cmpi eq, %ileft_445, %iright_446 : tensor<32x16xi32> loc(#loc282) + %cond_460 = arith.cmpi sgt, %left_idx_456, %right_idx_457 : tensor<32x16xi32> loc(#loc303) + %cond_461 = arith.andi %eq_459, %cond_460 : tensor<32x16xi1> loc(#loc283) + %cond_462 = arith.ori %cond_458, %cond_461 : tensor<32x16xi1> loc(#loc284) + %cond_463 = arith.extui %cond_462 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + %cond_464 = arith.xori %cond_463, %flip_62 : tensor<32x16xi32> loc(#loc285) + %cond_465 = arith.cmpi ne, %cond_464, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_466 = arith.xori %ileft_445, %iright_446 : tensor<32x16xi32> loc(#loc287) + %ret_467 = arith.select %cond_465, %ret_466, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_468 = arith.xori %ret_432, %ret_467 : tensor<32x16xi32> loc(#loc289) + %new_idxs_469 = arith.xori %left_idx_456, %right_idx_457 : tensor<32x16xi32> loc(#loc304) + %new_idxs_470 = arith.select %cond_465, %new_idxs_469, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_471 = arith.xori %new_idxs_435, %new_idxs_470 : tensor<32x16xi32> loc(#loc291) + %y_472 = tt.reshape %ret_468 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc270) + %ileft_473 = arith.muli %y_472, %ileft_139 : tensor<64x2x4xi32> loc(#loc271) + %ileft_474 = 
"tt.reduce"(%ileft_473) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc317) + %ileft_475 = tt.expand_dims %ileft_474 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc273) + %ileft_476 = tt.broadcast %ileft_475 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc274) + %iright_477 = arith.muli %y_472, %flip_61 : tensor<64x2x4xi32> loc(#loc275) + %iright_478 = "tt.reduce"(%iright_477) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc319) + %iright_479 = tt.expand_dims %iright_478 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc277) + %iright_480 = tt.broadcast %iright_479 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc278) + %ileft_481 = tt.reshape %ileft_476 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_482 = tt.reshape %iright_480 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_483 = tt.reshape %new_idxs_471 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc292) + %left_idx_484 = arith.muli %y_idx_483, %ileft_139 : tensor<64x2x4xi32> loc(#loc293) + %left_idx_485 = "tt.reduce"(%left_idx_484) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc321) + %left_idx_486 = tt.expand_dims %left_idx_485 {axis = 1 : i32} : tensor<64x4xi32> 
-> tensor<64x1x4xi32> loc(#loc295) + %left_idx_487 = tt.broadcast %left_idx_486 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc296) + %right_idx_488 = arith.muli %y_idx_483, %flip_61 : tensor<64x2x4xi32> loc(#loc297) + %right_idx_489 = "tt.reduce"(%right_idx_488) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc323) + %right_idx_490 = tt.expand_dims %right_idx_489 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc299) + %right_idx_491 = tt.broadcast %right_idx_490 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc300) + %left_idx_492 = tt.reshape %left_idx_487 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_493 = tt.reshape %right_idx_491 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_494 = arith.cmpi slt, %ileft_481, %iright_482 : tensor<32x16xi32> loc(#loc281) + %eq_495 = arith.cmpi eq, %ileft_481, %iright_482 : tensor<32x16xi32> loc(#loc282) + %cond_496 = arith.cmpi sgt, %left_idx_492, %right_idx_493 : tensor<32x16xi32> loc(#loc303) + %cond_497 = arith.andi %eq_495, %cond_496 : tensor<32x16xi1> loc(#loc283) + %cond_498 = arith.ori %cond_494, %cond_497 : tensor<32x16xi1> loc(#loc284) + %cond_499 = arith.extui %cond_498 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + %cond_500 = arith.xori %cond_499, %flip_137 : tensor<32x16xi32> loc(#loc285) + %cond_501 = arith.cmpi ne, %cond_500, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_502 = arith.xori %ileft_481, %iright_482 : tensor<32x16xi32> loc(#loc287) + %ret_503 = arith.select %cond_501, %ret_502, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_504 = arith.xori %ret_468, %ret_503 : tensor<32x16xi32> loc(#loc289) + %new_idxs_505 = arith.xori %left_idx_492, 
%right_idx_493 : tensor<32x16xi32> loc(#loc304) + %new_idxs_506 = arith.select %cond_501, %new_idxs_505, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_507 = arith.xori %new_idxs_471, %new_idxs_506 : tensor<32x16xi32> loc(#loc291) + %y_508 = tt.reshape %ret_504 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc270) + %ileft_509 = arith.muli %y_508, %ileft_64 : tensor<128x2x2xi32> loc(#loc271) + %ileft_510 = "tt.reduce"(%ileft_509) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc317) + %ileft_511 = tt.expand_dims %ileft_510 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc273) + %ileft_512 = tt.broadcast %ileft_511 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc274) + %iright_513 = arith.muli %y_508, %flip_24 : tensor<128x2x2xi32> loc(#loc275) + %iright_514 = "tt.reduce"(%iright_513) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc319) + %iright_515 = tt.expand_dims %iright_514 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc277) + %iright_516 = tt.broadcast %iright_515 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc278) + %ileft_517 = tt.reshape %ileft_512 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_518 = tt.reshape %iright_516 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_519 = tt.reshape %new_idxs_507 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc292) + %left_idx_520 = arith.muli %y_idx_519, %ileft_64 : tensor<128x2x2xi32> loc(#loc293) + 
%left_idx_521 = "tt.reduce"(%left_idx_520) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc321) + %left_idx_522 = tt.expand_dims %left_idx_521 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc295) + %left_idx_523 = tt.broadcast %left_idx_522 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc296) + %right_idx_524 = arith.muli %y_idx_519, %flip_24 : tensor<128x2x2xi32> loc(#loc297) + %right_idx_525 = "tt.reduce"(%right_idx_524) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc323) + %right_idx_526 = tt.expand_dims %right_idx_525 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc299) + %right_idx_527 = tt.broadcast %right_idx_526 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc300) + %left_idx_528 = tt.reshape %left_idx_523 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_529 = tt.reshape %right_idx_527 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_530 = arith.cmpi slt, %ileft_517, %iright_518 : tensor<32x16xi32> loc(#loc281) + %eq_531 = arith.cmpi eq, %ileft_517, %iright_518 : tensor<32x16xi32> loc(#loc282) + %cond_532 = arith.cmpi sgt, %left_idx_528, %right_idx_529 : tensor<32x16xi32> loc(#loc303) + %cond_533 = arith.andi %eq_531, %cond_532 : tensor<32x16xi1> loc(#loc283) + %cond_534 = arith.ori %cond_530, %cond_533 : tensor<32x16xi1> loc(#loc284) + %cond_535 = arith.extui %cond_534 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + 
%cond_536 = arith.xori %cond_535, %flip_137 : tensor<32x16xi32> loc(#loc285) + %cond_537 = arith.cmpi ne, %cond_536, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_538 = arith.xori %ileft_517, %iright_518 : tensor<32x16xi32> loc(#loc287) + %ret_539 = arith.select %cond_537, %ret_538, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_540 = arith.xori %ret_504, %ret_539 : tensor<32x16xi32> loc(#loc289) + %new_idxs_541 = arith.xori %left_idx_528, %right_idx_529 : tensor<32x16xi32> loc(#loc304) + %new_idxs_542 = arith.select %cond_537, %new_idxs_541, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_543 = arith.xori %new_idxs_507, %new_idxs_542 : tensor<32x16xi32> loc(#loc291) + %y_544 = tt.reshape %ret_540 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc270) + %ileft_545 = arith.muli %y_544, %ileft : tensor<256x2x1xi32> loc(#loc271) + %ileft_546 = "tt.reduce"(%ileft_545) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc317) + %ileft_547 = tt.expand_dims %ileft_546 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc273) + %ileft_548 = tt.broadcast %ileft_547 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc274) + %iright_549 = arith.muli %y_544, %iright : tensor<256x2x1xi32> loc(#loc275) + %iright_550 = "tt.reduce"(%iright_549) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc319) + %iright_551 = tt.expand_dims %iright_550 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc277) + 
%iright_552 = tt.broadcast %iright_551 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc278) + %ileft_553 = tt.reshape %ileft_548 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_554 = tt.reshape %iright_552 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_555 = tt.reshape %new_idxs_543 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_556 = arith.muli %y_idx_555, %ileft : tensor<256x2x1xi32> loc(#loc293) + %left_idx_557 = "tt.reduce"(%left_idx_556) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc321) + %left_idx_558 = tt.expand_dims %left_idx_557 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc295) + %left_idx_559 = tt.broadcast %left_idx_558 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc296) + %right_idx_560 = arith.muli %y_idx_555, %iright : tensor<256x2x1xi32> loc(#loc297) + %right_idx_561 = "tt.reduce"(%right_idx_560) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc323) + %right_idx_562 = tt.expand_dims %right_idx_561 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc299) + %right_idx_563 = tt.broadcast %right_idx_562 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc300) + %left_idx_564 = tt.reshape %left_idx_559 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_565 = tt.reshape %right_idx_563 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_566 = arith.cmpi slt, %ileft_553, 
%iright_554 : tensor<32x16xi32> loc(#loc281) + %eq_567 = arith.cmpi eq, %ileft_553, %iright_554 : tensor<32x16xi32> loc(#loc282) + %cond_568 = arith.cmpi sgt, %left_idx_564, %right_idx_565 : tensor<32x16xi32> loc(#loc303) + %cond_569 = arith.andi %eq_567, %cond_568 : tensor<32x16xi1> loc(#loc283) + %cond_570 = arith.ori %cond_566, %cond_569 : tensor<32x16xi1> loc(#loc284) + %cond_571 = arith.extui %cond_570 : tensor<32x16xi1> to tensor<32x16xi32> loc(#loc285) + %cond_572 = arith.xori %cond_571, %flip_137 : tensor<32x16xi32> loc(#loc285) + %cond_573 = arith.cmpi ne, %cond_572, %cst_3 : tensor<32x16xi32> loc(#loc286) + %ret_574 = arith.xori %ileft_553, %iright_554 : tensor<32x16xi32> loc(#loc287) + %ret_575 = arith.select %cond_573, %ret_574, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_576 = arith.xori %ret_540, %ret_575 : tensor<32x16xi32> loc(#loc289) + %new_idxs_577 = arith.xori %left_idx_564, %right_idx_565 : tensor<32x16xi32> loc(#loc304) + %new_idxs_578 = arith.select %cond_573, %new_idxs_577, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_579 = arith.xori %new_idxs_543, %new_idxs_578 : tensor<32x16xi32> loc(#loc291) + %y_580 = tt.reshape %ret_576 : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc270) + %ileft_581 = arith.muli %y_580, %ileft_248 : tensor<32x2x8xi32> loc(#loc271) + %ileft_582 = "tt.reduce"(%ileft_581) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc317) + %ileft_583 = tt.expand_dims %ileft_582 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc273) + %ileft_584 = tt.broadcast %ileft_583 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc274) + %iright_585 = arith.muli %y_580, %flip_136 : tensor<32x2x8xi32> loc(#loc275) + %iright_586 = 
"tt.reduce"(%iright_585) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc319) + %iright_587 = tt.expand_dims %iright_586 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc277) + %iright_588 = tt.broadcast %iright_587 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc278) + %ileft_589 = tt.reshape %ileft_584 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_590 = tt.reshape %iright_588 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_591 = tt.reshape %new_idxs_579 : tensor<32x16xi32> -> tensor<32x2x8xi32> loc(#loc292) + %left_idx_592 = arith.muli %y_idx_591, %ileft_248 : tensor<32x2x8xi32> loc(#loc293) + %left_idx_593 = "tt.reduce"(%left_idx_592) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc321) + %left_idx_594 = tt.expand_dims %left_idx_593 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc295) + %left_idx_595 = tt.broadcast %left_idx_594 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc296) + %right_idx_596 = arith.muli %y_idx_591, %flip_136 : tensor<32x2x8xi32> loc(#loc297) + %right_idx_597 = "tt.reduce"(%right_idx_596) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<32x2x8xi32>) -> tensor<32x8xi32> loc(#loc323) + %right_idx_598 = tt.expand_dims 
%right_idx_597 {axis = 1 : i32} : tensor<32x8xi32> -> tensor<32x1x8xi32> loc(#loc299) + %right_idx_599 = tt.broadcast %right_idx_598 : tensor<32x1x8xi32> -> tensor<32x2x8xi32> loc(#loc300) + %left_idx_600 = tt.reshape %left_idx_595 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_601 = tt.reshape %right_idx_599 : tensor<32x2x8xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_602 = arith.cmpi slt, %ileft_589, %iright_590 : tensor<32x16xi32> loc(#loc281) + %eq_603 = arith.cmpi eq, %ileft_589, %iright_590 : tensor<32x16xi32> loc(#loc282) + %cond_604 = arith.cmpi sgt, %left_idx_600, %right_idx_601 : tensor<32x16xi32> loc(#loc303) + %cond_605 = arith.andi %eq_603, %cond_604 : tensor<32x16xi1> loc(#loc283) + %cond_606 = arith.ori %cond_602, %cond_605 : tensor<32x16xi1> loc(#loc284) + %ret_607 = arith.xori %ileft_589, %iright_590 : tensor<32x16xi32> loc(#loc287) + %ret_608 = arith.select %cond_606, %ret_607, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_609 = arith.xori %ret_576, %ret_608 : tensor<32x16xi32> loc(#loc289) + %new_idxs_610 = arith.xori %left_idx_600, %right_idx_601 : tensor<32x16xi32> loc(#loc304) + %new_idxs_611 = arith.select %cond_606, %new_idxs_610, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_612 = arith.xori %new_idxs_579, %new_idxs_611 : tensor<32x16xi32> loc(#loc291) + %y_613 = tt.reshape %ret_609 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc270) + %ileft_614 = arith.muli %y_613, %ileft_139 : tensor<64x2x4xi32> loc(#loc271) + %ileft_615 = "tt.reduce"(%ileft_614) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc317) + %ileft_616 = tt.expand_dims %ileft_615 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc273) + %ileft_617 
= tt.broadcast %ileft_616 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc274) + %iright_618 = arith.muli %y_613, %flip_61 : tensor<64x2x4xi32> loc(#loc275) + %iright_619 = "tt.reduce"(%iright_618) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc319) + %iright_620 = tt.expand_dims %iright_619 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc277) + %iright_621 = tt.broadcast %iright_620 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc278) + %ileft_622 = tt.reshape %ileft_617 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_623 = tt.reshape %iright_621 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_624 = tt.reshape %new_idxs_612 : tensor<32x16xi32> -> tensor<64x2x4xi32> loc(#loc292) + %left_idx_625 = arith.muli %y_idx_624, %ileft_139 : tensor<64x2x4xi32> loc(#loc293) + %left_idx_626 = "tt.reduce"(%left_idx_625) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc321) + %left_idx_627 = tt.expand_dims %left_idx_626 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc295) + %left_idx_628 = tt.broadcast %left_idx_627 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc296) + %right_idx_629 = arith.muli %y_idx_624, %flip_61 : tensor<64x2x4xi32> loc(#loc297) + %right_idx_630 = "tt.reduce"(%right_idx_629) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, 
%right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<64x2x4xi32>) -> tensor<64x4xi32> loc(#loc323) + %right_idx_631 = tt.expand_dims %right_idx_630 {axis = 1 : i32} : tensor<64x4xi32> -> tensor<64x1x4xi32> loc(#loc299) + %right_idx_632 = tt.broadcast %right_idx_631 : tensor<64x1x4xi32> -> tensor<64x2x4xi32> loc(#loc300) + %left_idx_633 = tt.reshape %left_idx_628 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_634 = tt.reshape %right_idx_632 : tensor<64x2x4xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_635 = arith.cmpi slt, %ileft_622, %iright_623 : tensor<32x16xi32> loc(#loc281) + %eq_636 = arith.cmpi eq, %ileft_622, %iright_623 : tensor<32x16xi32> loc(#loc282) + %cond_637 = arith.cmpi sgt, %left_idx_633, %right_idx_634 : tensor<32x16xi32> loc(#loc303) + %cond_638 = arith.andi %eq_636, %cond_637 : tensor<32x16xi1> loc(#loc283) + %cond_639 = arith.ori %cond_635, %cond_638 : tensor<32x16xi1> loc(#loc284) + %ret_640 = arith.xori %ileft_622, %iright_623 : tensor<32x16xi32> loc(#loc287) + %ret_641 = arith.select %cond_639, %ret_640, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_642 = arith.xori %ret_609, %ret_641 : tensor<32x16xi32> loc(#loc289) + %new_idxs_643 = arith.xori %left_idx_633, %right_idx_634 : tensor<32x16xi32> loc(#loc304) + %new_idxs_644 = arith.select %cond_639, %new_idxs_643, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_645 = arith.xori %new_idxs_612, %new_idxs_644 : tensor<32x16xi32> loc(#loc291) + %y_646 = tt.reshape %ret_642 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc270) + %ileft_647 = arith.muli %y_646, %ileft_64 : tensor<128x2x2xi32> loc(#loc271) + %ileft_648 = "tt.reduce"(%ileft_647) <{axis = 1 : i32}> ({ + ^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + 
}) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc317) + %ileft_649 = tt.expand_dims %ileft_648 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc273) + %ileft_650 = tt.broadcast %ileft_649 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc274) + %iright_651 = arith.muli %y_646, %flip_24 : tensor<128x2x2xi32> loc(#loc275) + %iright_652 = "tt.reduce"(%iright_651) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc319) + %iright_653 = tt.expand_dims %iright_652 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc277) + %iright_654 = tt.broadcast %iright_653 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc278) + %ileft_655 = tt.reshape %ileft_650 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_656 = tt.reshape %iright_654 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_657 = tt.reshape %new_idxs_645 : tensor<32x16xi32> -> tensor<128x2x2xi32> loc(#loc292) + %left_idx_658 = arith.muli %y_idx_657, %ileft_64 : tensor<128x2x2xi32> loc(#loc293) + %left_idx_659 = "tt.reduce"(%left_idx_658) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc321) + %left_idx_660 = tt.expand_dims %left_idx_659 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc295) + %left_idx_661 = tt.broadcast %left_idx_660 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc296) + %right_idx_662 = arith.muli %y_idx_657, %flip_24 : tensor<128x2x2xi32> loc(#loc297) + %right_idx_663 = 
"tt.reduce"(%right_idx_662) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<128x2x2xi32>) -> tensor<128x2xi32> loc(#loc323) + %right_idx_664 = tt.expand_dims %right_idx_663 {axis = 1 : i32} : tensor<128x2xi32> -> tensor<128x1x2xi32> loc(#loc299) + %right_idx_665 = tt.broadcast %right_idx_664 : tensor<128x1x2xi32> -> tensor<128x2x2xi32> loc(#loc300) + %left_idx_666 = tt.reshape %left_idx_661 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_667 = tt.reshape %right_idx_665 : tensor<128x2x2xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_668 = arith.cmpi slt, %ileft_655, %iright_656 : tensor<32x16xi32> loc(#loc281) + %eq_669 = arith.cmpi eq, %ileft_655, %iright_656 : tensor<32x16xi32> loc(#loc282) + %cond_670 = arith.cmpi sgt, %left_idx_666, %right_idx_667 : tensor<32x16xi32> loc(#loc303) + %cond_671 = arith.andi %eq_669, %cond_670 : tensor<32x16xi1> loc(#loc283) + %cond_672 = arith.ori %cond_668, %cond_671 : tensor<32x16xi1> loc(#loc284) + %ret_673 = arith.xori %ileft_655, %iright_656 : tensor<32x16xi32> loc(#loc287) + %ret_674 = arith.select %cond_672, %ret_673, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc288) + %ret_675 = arith.xori %ret_642, %ret_674 : tensor<32x16xi32> loc(#loc289) + %new_idxs_676 = arith.xori %left_idx_666, %right_idx_667 : tensor<32x16xi32> loc(#loc304) + %new_idxs_677 = arith.select %cond_672, %new_idxs_676, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_678 = arith.xori %new_idxs_645, %new_idxs_677 : tensor<32x16xi32> loc(#loc291) + %y_679 = tt.reshape %ret_675 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc270) + %ileft_680 = arith.muli %y_679, %ileft : tensor<256x2x1xi32> loc(#loc271) + %ileft_681 = "tt.reduce"(%ileft_680) <{axis = 1 : i32}> ({ + 
^bb0(%ileft_713: i32 loc(callsite(#loc1 at #loc272)), %ileft_714: i32 loc(callsite(#loc1 at #loc272))): + %ileft_715 = arith.addi %ileft_713, %ileft_714 : i32 loc(#loc329) + tt.reduce.return %ileft_715 : i32 loc(#loc317) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc317) + %ileft_682 = tt.expand_dims %ileft_681 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc273) + %ileft_683 = tt.broadcast %ileft_682 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc274) + %iright_684 = arith.muli %y_679, %iright : tensor<256x2x1xi32> loc(#loc275) + %iright_685 = "tt.reduce"(%iright_684) <{axis = 1 : i32}> ({ + ^bb0(%iright_713: i32 loc(callsite(#loc1 at #loc276)), %iright_714: i32 loc(callsite(#loc1 at #loc276))): + %iright_715 = arith.addi %iright_713, %iright_714 : i32 loc(#loc330) + tt.reduce.return %iright_715 : i32 loc(#loc319) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc319) + %iright_686 = tt.expand_dims %iright_685 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc277) + %iright_687 = tt.broadcast %iright_686 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc278) + %ileft_688 = tt.reshape %ileft_683 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc279) + %iright_689 = tt.reshape %iright_687 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc280) + %y_idx_690 = tt.reshape %new_idxs_678 : tensor<32x16xi32> -> tensor<256x2x1xi32> loc(#loc292) + %left_idx_691 = arith.muli %y_idx_690, %ileft : tensor<256x2x1xi32> loc(#loc293) + %left_idx_692 = "tt.reduce"(%left_idx_691) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_713: i32 loc(callsite(#loc1 at #loc294)), %left_idx_714: i32 loc(callsite(#loc1 at #loc294))): + %left_idx_715 = arith.addi %left_idx_713, %left_idx_714 : i32 loc(#loc331) + tt.reduce.return %left_idx_715 : i32 loc(#loc321) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc321) + %left_idx_693 = tt.expand_dims %left_idx_692 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> 
loc(#loc295) + %left_idx_694 = tt.broadcast %left_idx_693 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc296) + %right_idx_695 = arith.muli %y_idx_690, %iright : tensor<256x2x1xi32> loc(#loc297) + %right_idx_696 = "tt.reduce"(%right_idx_695) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_713: i32 loc(callsite(#loc1 at #loc298)), %right_idx_714: i32 loc(callsite(#loc1 at #loc298))): + %right_idx_715 = arith.addi %right_idx_713, %right_idx_714 : i32 loc(#loc332) + tt.reduce.return %right_idx_715 : i32 loc(#loc323) + }) : (tensor<256x2x1xi32>) -> tensor<256x1xi32> loc(#loc323) + %right_idx_697 = tt.expand_dims %right_idx_696 {axis = 1 : i32} : tensor<256x1xi32> -> tensor<256x1x1xi32> loc(#loc299) + %right_idx_698 = tt.broadcast %right_idx_697 : tensor<256x1x1xi32> -> tensor<256x2x1xi32> loc(#loc300) + %left_idx_699 = tt.reshape %left_idx_694 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc301) + %right_idx_700 = tt.reshape %right_idx_698 : tensor<256x2x1xi32> -> tensor<32x16xi32> loc(#loc302) + %cond_701 = arith.cmpi slt, %ileft_688, %iright_689 : tensor<32x16xi32> loc(#loc281) + %eq_702 = arith.cmpi eq, %ileft_688, %iright_689 : tensor<32x16xi32> loc(#loc282) + %cond_703 = arith.cmpi sgt, %left_idx_699, %right_idx_700 : tensor<32x16xi32> loc(#loc303) + %cond_704 = arith.andi %eq_702, %cond_703 : tensor<32x16xi1> loc(#loc283) + %cond_705 = arith.ori %cond_701, %cond_704 : tensor<32x16xi1> loc(#loc284) + %new_idxs_706 = arith.xori %left_idx_699, %right_idx_700 : tensor<32x16xi32> loc(#loc304) + %new_idxs_707 = arith.select %cond_705, %new_idxs_706, %cst_3 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc290) + %new_idxs_708 = arith.xori %new_idxs_678, %new_idxs_707 : tensor<32x16xi32> loc(#loc291) + %tmp20 = arith.extui %tmp5 : tensor<32x16xi1> to tensor<32x16xi64> loc(#loc226) + %tmp23 = arith.select %tmp0_20, %tmp20, %cst_7 : tensor<32x16xi1>, tensor<32x16xi64> loc(#loc197) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_713: i64 loc(callsite(#loc1 
at #loc198)), %tmp24_714: i64 loc(callsite(#loc1 at #loc198))): + %tmp24_715 = arith.addi %tmp24_713, %tmp24_714 : i64 loc(#loc305) + tt.reduce.return %tmp24_715 : i64 loc(#loc227) + }) : (tensor<32x16xi64>) -> tensor<32xi64> loc(#loc227) + %tmp24_709 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> loc(#loc199) + %tmp25 = arith.extui %tmp14 : tensor<32x16xi1> to tensor<32x16xi64> loc(#loc229) + %tmp28 = arith.select %tmp0_20, %tmp25, %cst_7 : tensor<32x16xi1>, tensor<32x16xi64> loc(#loc201) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_713: i64 loc(callsite(#loc1 at #loc202)), %tmp29_714: i64 loc(callsite(#loc1 at #loc202))): + %tmp29_715 = arith.addi %tmp29_713, %tmp29_714 : i64 loc(#loc306) + tt.reduce.return %tmp29_715 : i64 loc(#loc230) + }) : (tensor<32x16xi64>) -> tensor<32xi64> loc(#loc230) + %tmp29_710 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64> loc(#loc203) + %tmp30 = arith.trunci %tmp24_709 : tensor<32x1xi64> to tensor<32x1xi32> loc(#loc204) + %tmp31 = arith.trunci %tmp29_710 : tensor<32x1xi64> to tensor<32x1xi32> loc(#loc205) + %tmp34 = tt.broadcast %tmp30 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc206) + %tmp34_711 = arith.cmpi slt, %tmp0_15, %tmp34 : tensor<32x16xi32> loc(#loc206) + %tmp36 = arith.select %tmp34_711, %new_idxs_376, %cst_5 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc207) + %tmp38 = arith.addi %tmp36, %cst_4 : tensor<32x16xi32> loc(#loc208) + %tmp39 = arith.cmpi slt, %tmp36, %cst_3 : tensor<32x16xi32> loc(#loc209) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc210) + %0 = arith.cmpi sge, %tmp40, %cst_3 : tensor<32x16xi32> loc(#loc88) + %1 = arith.cmpi slt, %tmp40, %cst_4 : tensor<32x16xi32> loc(#loc89) + %2 = arith.andi %0, %1 : tensor<32x16xi1> loc(#loc90) + %3 = arith.xori %xmask_13, %cst_2 : tensor<32x1xi1> loc(#loc91) + %4 = tt.broadcast %3 : tensor<32x1xi1> -> tensor<32x16xi1> loc(#loc92) + %5 = 
arith.ori %2, %4 : tensor<32x16xi1> loc(#loc92) + tt.assert %5, "index out of bounds: 0 <= tmp40 < 17" : tensor<32x16xi1> loc(#loc93) + %tmp45 = tt.broadcast %tmp31 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc211) + %tmp45_712 = arith.cmpi slt, %tmp0_15, %tmp45 : tensor<32x16xi32> loc(#loc211) + %tmp46 = arith.select %tmp45_712, %new_idxs_708, %cst_5 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc212) + %tmp47 = arith.addi %tmp46, %cst_4 : tensor<32x16xi32> loc(#loc213) + %tmp48 = arith.cmpi slt, %tmp46, %cst_3 : tensor<32x16xi32> loc(#loc214) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<32x16xi1>, tensor<32x16xi32> loc(#loc215) + %6 = arith.cmpi sge, %tmp49, %cst_3 : tensor<32x16xi32> loc(#loc99) + %7 = arith.cmpi slt, %tmp49, %cst_4 : tensor<32x16xi32> loc(#loc100) + %8 = arith.andi %6, %7 : tensor<32x16xi1> loc(#loc101) + %9 = arith.ori %8, %4 : tensor<32x16xi1> loc(#loc102) + tt.assert %9, "index out of bounds: 0 <= tmp49 < 17" : tensor<32x16xi1> loc(#loc103) + %10 = tt.splat %out_ptr4 : !tt.ptr -> tensor<32x1x!tt.ptr> loc(#loc104) + %11 = tt.addptr %10, %xindex_12 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> loc(#loc104) + tt.store %11, %tmp30, %xmask_13 : tensor<32x1x!tt.ptr> loc(#loc105) + %12 = tt.splat %out_ptr5 : !tt.ptr -> tensor<32x1x!tt.ptr> loc(#loc106) + %13 = tt.addptr %12, %xindex_12 : tensor<32x1x!tt.ptr>, tensor<32x1xi32> loc(#loc106) + tt.store %13, %tmp31, %xmask_13 : tensor<32x1x!tt.ptr> loc(#loc107) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc108) + %15 = tt.addptr %14, %tmp0_17 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc108) + tt.store %15, %new_idxs_376, %tmp0_20 : tensor<32x16x!tt.ptr> loc(#loc109) + %16 = arith.muli %xindex_12, %cst_1 : tensor<32x1xi32> loc(#loc110) + %17 = tt.broadcast %16 : tensor<32x1xi32> -> tensor<32x16xi32> loc(#loc111) + %18 = arith.addi %tmp40, %17 : tensor<32x16xi32> loc(#loc111) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc112) + %20 = 
tt.addptr %19, %18 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc112) + tt.store %20, %cst_0, %tmp0_20 : tensor<32x16x!tt.ptr> loc(#loc113) + %21 = tt.splat %out_ptr8 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc114) + %22 = tt.addptr %21, %tmp0_17 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc114) + tt.store %22, %new_idxs_708, %tmp0_20 : tensor<32x16x!tt.ptr> loc(#loc115) + %23 = arith.addi %tmp49, %17 : tensor<32x16xi32> loc(#loc116) + %24 = tt.splat %out_ptr9 : !tt.ptr -> tensor<32x16x!tt.ptr> loc(#loc117) + %25 = tt.addptr %24, %23 : tensor<32x16x!tt.ptr>, tensor<32x16xi32> loc(#loc117) + tt.store %25, %cst_0, %tmp0_20 : tensor<32x16x!tt.ptr> loc(#loc118) + tt.return loc(#loc119) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":26:21) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":24:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":25:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":27:28) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":27:38) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:40) +#loc11 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:30) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":34:45) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":36:18) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":38:18) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":39:18) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":41:19) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":40:19) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":43:19) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":45:34) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc29 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc49 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc65 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc67 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc68 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":47:20) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":49:21) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":48:21) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":52:20) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":54:35) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":55:29) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":56:21) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":58:35) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":59:29) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":60:21) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":61:21) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":64:19) +#loc84 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":66:35) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":68:20) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":69:20) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":70:35) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:28) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:46) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:38) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:55) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:53) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":71:63) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":75:19) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":76:35) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":77:20) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":78:20) +#loc98 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":79:35) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:28) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:46) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:38) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:53) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":80:63) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:25) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":81:37) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:25) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":82:37) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:25) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":83:47) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:52) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:49) +#loc112 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:25) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":84:85) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:25) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":85:47) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:49) +#loc117 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:25) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:85) +#loc119 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/sv/csv46z4ndfd65eeb2whwsukjehydbct65zjp5g5qphwxvvaec3vy.py":86:4) +#loc129 = loc("xmask"(#loc2)) +#loc130 = loc("xoffset"(#loc3)) +#loc131 = loc("xoffset"(#loc4)) +#loc132 = loc("xindex"(#loc5)) +#loc133 = loc("xindex"(#loc6)) +#loc134 = loc("xindex"(#loc7)) +#loc135 = loc("r0_index"(#loc8)) +#loc136 = loc("r0_index"(#loc9)) +#loc137 = loc("tmp0"(#loc10)) +#loc138 = loc("tmp0"(#loc11)) +#loc139 = loc("tmp0"(#loc12)) +#loc140 = loc("tmp0"(#loc13)) +#loc141 = loc("tmp2"(#loc14)) +#loc142 = loc("tmp4"(#loc15)) +#loc143 = loc("tmp5"(#loc16)) +#loc144 = loc("tmp7"(#loc17)) +#loc145 = loc("tmp6"(#loc18)) +#loc146 = loc("tmp9"(#loc19)) +#loc147 = loc("tmp11"(#loc20)) +#loc148 = loc("flip"(#loc21)) +#loc150 = loc("flip"(#loc24)) +#loc151 = loc("flip"(#loc25)) +#loc152 = loc("flip"(#loc26)) +#loc153 = loc("y"(#loc27)) +#loc154 = loc("left_mask"(#loc29)) +#loc155 = loc("ileft"(#loc30)) +#loc157 = loc("ileft"(#loc34)) +#loc158 = loc("ileft"(#loc35)) +#loc159 = 
loc("iright"(#loc36)) +#loc161 = loc("iright"(#loc38)) +#loc162 = loc("iright"(#loc39)) +#loc163 = loc("ileft"(#loc40)) +#loc164 = loc("iright"(#loc41)) +#loc165 = loc("y_idx"(#loc42)) +#loc166 = loc("left_idx"(#loc43)) +#loc167 = loc("left_idx"(#loc44)) +#loc168 = loc("input"(#loc45)) +#loc170 = loc("left_idx"(#loc47)) +#loc171 = loc("left_idx"(#loc48)) +#loc172 = loc("right_idx"(#loc49)) +#loc173 = loc("right_idx"(#loc50)) +#loc175 = loc("right_idx"(#loc52)) +#loc176 = loc("right_idx"(#loc53)) +#loc177 = loc("left_idx"(#loc54)) +#loc178 = loc("right_idx"(#loc55)) +#loc179 = loc("cond"(#loc56)) +#loc180 = loc("eq"(#loc57)) +#loc181 = loc("cond"(#loc58)) +#loc182 = loc("cond"(#loc59)) +#loc183 = loc("cond"(#loc60)) +#loc184 = loc("cond"(#loc61)) +#loc185 = loc("cond"(#loc62)) +#loc186 = loc("ret"(#loc63)) +#loc187 = loc("ret"(#loc64)) +#loc188 = loc("ret"(#loc65)) +#loc189 = loc("new_idxs"(#loc66)) +#loc190 = loc("new_idxs"(#loc67)) +#loc191 = loc("new_idxs"(#loc68)) +#loc192 = loc("tmp14"(#loc69)) +#loc193 = loc("tmp16"(#loc70)) +#loc194 = loc("tmp15"(#loc71)) +#loc196 = loc("tmp20"(#loc73)) +#loc197 = loc("tmp23"(#loc74)) +#loc199 = loc("tmp24"(#loc76)) +#loc200 = loc("tmp25"(#loc77)) +#loc201 = loc("tmp28"(#loc78)) +#loc203 = loc("tmp29"(#loc80)) +#loc204 = loc("tmp30"(#loc81)) +#loc205 = loc("tmp31"(#loc82)) +#loc206 = loc("tmp34"(#loc83)) +#loc207 = loc("tmp36"(#loc84)) +#loc208 = loc("tmp38"(#loc85)) +#loc209 = loc("tmp39"(#loc86)) +#loc210 = loc("tmp40"(#loc87)) +#loc211 = loc("tmp45"(#loc94)) +#loc212 = loc("tmp46"(#loc95)) +#loc213 = loc("tmp47"(#loc96)) +#loc214 = loc("tmp48"(#loc97)) +#loc215 = loc("tmp49"(#loc98)) +#loc216 = loc(fused[#loc144, #loc145]) +#loc217 = loc(callsite(#loc148 at #loc149)) +#loc218 = loc(callsite(#loc150 at #loc149)) +#loc219 = loc(callsite(#loc151 at #loc149)) +#loc220 = loc(callsite(#loc152 at #loc149)) +#loc222 = loc("cond"(#loc179)) +#loc223 = loc("eq"(#loc180)) +#loc224 = loc(fused[#loc193, #loc194]) +#loc226 = 
loc(fused[#loc196, #loc144, #loc145]) +#loc227 = loc(callsite(#loc31 at #loc198)) +#loc229 = loc(fused[#loc200, #loc193, #loc194]) +#loc230 = loc(callsite(#loc31 at #loc202)) +#loc232 = loc(callsite(#loc153 at #loc221)) +#loc233 = loc(callsite(#loc154 at #loc221)) +#loc234 = loc(callsite(#loc155 at #loc221)) +#loc236 = loc(callsite(#loc157 at #loc221)) +#loc237 = loc(callsite(#loc158 at #loc221)) +#loc238 = loc(callsite(#loc159 at #loc221)) +#loc240 = loc(callsite(#loc161 at #loc221)) +#loc241 = loc(callsite(#loc162 at #loc221)) +#loc242 = loc(callsite(#loc163 at #loc221)) +#loc243 = loc(callsite(#loc164 at #loc221)) +#loc244 = loc(callsite(#loc165 at #loc221)) +#loc245 = loc(callsite(#loc166 at #loc221)) +#loc246 = loc(callsite(#loc167 at #loc221)) +#loc248 = loc(callsite(#loc170 at #loc221)) +#loc249 = loc(callsite(#loc171 at #loc221)) +#loc250 = loc(callsite(#loc172 at #loc221)) +#loc251 = loc(callsite(#loc173 at #loc221)) +#loc253 = loc(callsite(#loc175 at #loc221)) +#loc254 = loc(callsite(#loc176 at #loc221)) +#loc255 = loc(callsite(#loc177 at #loc221)) +#loc256 = loc(callsite(#loc178 at #loc221)) +#loc257 = loc(callsite(#loc222 at #loc221)) +#loc258 = loc(callsite(#loc223 at #loc221)) +#loc259 = loc(callsite(#loc181 at #loc221)) +#loc260 = loc(callsite(#loc182 at #loc221)) +#loc261 = loc(callsite(#loc183 at #loc221)) +#loc262 = loc(callsite(#loc184 at #loc221)) +#loc263 = loc(callsite(#loc185 at #loc221)) +#loc264 = loc(callsite(#loc186 at #loc221)) +#loc265 = loc(callsite(#loc187 at #loc221)) +#loc266 = loc(callsite(#loc188 at #loc221)) +#loc267 = loc(callsite(#loc189 at #loc221)) +#loc268 = loc(callsite(#loc190 at #loc221)) +#loc269 = loc(callsite(#loc191 at #loc221)) +#loc270 = loc(callsite(#loc153 at #loc225)) +#loc271 = loc(callsite(#loc155 at #loc225)) +#loc273 = loc(callsite(#loc157 at #loc225)) +#loc274 = loc(callsite(#loc158 at #loc225)) +#loc275 = loc(callsite(#loc159 at #loc225)) +#loc277 = loc(callsite(#loc161 at #loc225)) +#loc278 = 
loc(callsite(#loc162 at #loc225)) +#loc279 = loc(callsite(#loc163 at #loc225)) +#loc280 = loc(callsite(#loc164 at #loc225)) +#loc281 = loc(callsite(#loc222 at #loc225)) +#loc282 = loc(callsite(#loc223 at #loc225)) +#loc283 = loc(callsite(#loc182 at #loc225)) +#loc284 = loc(callsite(#loc183 at #loc225)) +#loc285 = loc(callsite(#loc184 at #loc225)) +#loc286 = loc(callsite(#loc185 at #loc225)) +#loc287 = loc(callsite(#loc186 at #loc225)) +#loc288 = loc(callsite(#loc187 at #loc225)) +#loc289 = loc(callsite(#loc188 at #loc225)) +#loc290 = loc(callsite(#loc190 at #loc225)) +#loc291 = loc(callsite(#loc191 at #loc225)) +#loc292 = loc(callsite(#loc165 at #loc225)) +#loc293 = loc(callsite(#loc167 at #loc225)) +#loc295 = loc(callsite(#loc170 at #loc225)) +#loc296 = loc(callsite(#loc171 at #loc225)) +#loc297 = loc(callsite(#loc173 at #loc225)) +#loc299 = loc(callsite(#loc175 at #loc225)) +#loc300 = loc(callsite(#loc176 at #loc225)) +#loc301 = loc(callsite(#loc177 at #loc225)) +#loc302 = loc(callsite(#loc178 at #loc225)) +#loc303 = loc(callsite(#loc181 at #loc225)) +#loc304 = loc(callsite(#loc189 at #loc225)) +#loc305 = loc(callsite(#loc33 at #loc227)) +#loc306 = loc(callsite(#loc33 at #loc230)) +#loc307 = loc(callsite(#loc31 at #loc235)) +#loc309 = loc(callsite(#loc31 at #loc239)) +#loc311 = loc(callsite(#loc168 at #loc247)) +#loc312 = loc(callsite(#loc31 at #loc247)) +#loc314 = loc(callsite(#loc168 at #loc252)) +#loc315 = loc(callsite(#loc31 at #loc252)) +#loc317 = loc(callsite(#loc31 at #loc272)) +#loc319 = loc(callsite(#loc31 at #loc276)) +#loc321 = loc(callsite(#loc31 at #loc294)) +#loc323 = loc(callsite(#loc31 at #loc298)) +#loc325 = loc(callsite(#loc33 at #loc307)) +#loc326 = loc(callsite(#loc33 at #loc309)) +#loc327 = loc(callsite(#loc33 at #loc312)) +#loc328 = loc(callsite(#loc33 at #loc315)) +#loc329 = loc(callsite(#loc33 at #loc317)) +#loc330 = loc(callsite(#loc33 at #loc319)) +#loc331 = loc(callsite(#loc33 at #loc321)) +#loc332 = loc(callsite(#loc33 at #loc323)) 
diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/__grp__triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/__grp__triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..ae56b94732e1533b6fe1c2e7b78837b78e003625 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/__grp__triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_zeros_0.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.source", "triton_red_fused_zeros_0.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttir", "triton_red_fused_zeros_0.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttgir", "triton_red_fused_zeros_0.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.llir", "triton_red_fused_zeros_0.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ptx", "triton_red_fused_zeros_0.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.cubin", "triton_red_fused_zeros_0.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.json"}} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..c998bc03c657bad8830daca5d607841ea2005312 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.json b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.json new file mode 100644 index 0000000000000000000000000000000000000000..e090e69169bebe56ef7149974c10d057f011f555 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.json @@ -0,0 +1 @@ +{"hash": "2b3c2b5fad1b69293a163ec86269abe219ddb58dbb38d468b4aa35ef6f1e2325", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", 
"tensordesc_meta": [], "shared": 256, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_zeros_0"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.llir new file mode 100644 index 0000000000000000000000000000000000000000..1741475d410ba33360c6962ddfa59127e4bfed44 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.llir @@ -0,0 +1,186 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_zeros_0(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i64 %3, i64 %4, i32 %5, i32 %6, ptr addrspace(1) readnone captures(none) %7, ptr addrspace(1) readnone captures(none) %8) local_unnamed_addr #0 !dbg !4 { + %10 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %11 = shl i32 %10, 6, !dbg !8 + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %13 = and i32 %12, 252, !dbg !9 + %14 = lshr exact i32 %13, 2, !dbg !9 + %15 = or disjoint i32 %14, %11, !dbg !10 + %16 = icmp slt i32 %15, %5, !dbg !11 + %17 = and i32 %12, 3, !dbg !12 + %18 = sext i32 %15 to i64, !dbg !13 + %.frozen = freeze i64 %3, !dbg !14 + %19 = sdiv i64 %18, %.frozen, !dbg !14 + %20 = mul i64 %19, %.frozen, !dbg !13 + %.decomposed = sub i64 %18, %20, !dbg !13 + %21 = srem i64 %19, 32, !dbg !15 + %22 = sdiv i64 %18, %4, !dbg !16 + %.not = icmp ne i64 %.decomposed, 0, 
!dbg !17 + %23 = icmp slt i32 %11, 0, !dbg !21 + %24 = icmp slt i64 %3, 0, !dbg !22 + %25 = xor i1 %23, %24, !dbg !23 + %narrow = select i1 %25, i1 %.not, i1 false, !dbg !24 + %26 = sext i1 %narrow to i64, !dbg !24 + %27 = add nsw i64 %19, %26, !dbg !24 + %28 = shl i64 %3, 12, !dbg !25 + %29 = mul i64 %28, %22, !dbg !26 + %30 = icmp slt i64 %3, 2, !dbg !27 + %31 = icmp sgt i64 %3, 1, !dbg !28 + %32 = select i1 %31, i64 %3, i64 0, !dbg !29 + %33 = zext i1 %30 to i64, !dbg !30 + %34 = add i64 %32, %33, !dbg !31 + %35 = shl i64 %34, 7, !dbg !32 + %36 = mul i64 %35, %27, !dbg !33 + %.idx = shl nsw i64 %21, 8 + %37 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx + %.idx1 = shl nsw i64 %.decomposed, 13 + %invariant.gep = getelementptr i8, ptr addrspace(1) %37, i64 %.idx1, !dbg !34 + %invariant.gep3 = getelementptr bfloat, ptr addrspace(1) %invariant.gep, i64 %29, !dbg !34 + %.idx2 = shl nsw i64 %.decomposed, 8 + %38 = getelementptr i8, ptr addrspace(1) %1, i64 %.idx2 + %invariant.gep5 = getelementptr bfloat, ptr addrspace(1) %38, i64 %36, !dbg !34 + %.fr = freeze i1 %16 + %39 = zext nneg i32 %17 to i64, !dbg !34 + br i1 %.fr, label %.split.us, label %.split + +.split.us: ; preds = %9, %.split.us + %indvars.iv9 = phi i64 [ %indvars.iv.next10, %.split.us ], [ 0, %9 ] + %40 = phi float [ %51, %.split.us ], [ 0.000000e+00, %9 ] + %41 = or disjoint i64 %indvars.iv9, %39, !dbg !35 + %gep4.us = getelementptr bfloat, ptr addrspace(1) %invariant.gep3, i64 %41, !dbg !36 + %42 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !37 + %43 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep4.us, i64 %42, i1 true) #4, !dbg !37 + %44 = bitcast i16 %43 to bfloat, !dbg !37 + %45 = fpext bfloat %44 to float, !dbg !38 + %gep.us = getelementptr bfloat, ptr addrspace(1) %invariant.gep5, i64 %41, !dbg 
!39 + %46 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !40 + %47 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep.us, i64 %46, i1 true) #4, !dbg !40 + %48 = bitcast i16 %47 to bfloat, !dbg !40 + %49 = fpext bfloat %48 to float, !dbg !41 + %50 = fmul float %45, %49, !dbg !42 + %51 = fadd float %40, %50, !dbg !43 + %indvars.iv.next10 = add nuw nsw i64 %indvars.iv9, 4, !dbg !34 + %52 = icmp samesign ult i64 %indvars.iv9, 124, !dbg !34 + br i1 %52, label %.split.us, label %.split7.us, !dbg !34 + +.split: ; preds = %9, %.split + %indvars.iv = phi i64 [ %indvars.iv.next, %.split ], [ 0, %9 ] + %53 = or disjoint i64 %indvars.iv, %39, !dbg !35 + %gep4 = getelementptr bfloat, ptr addrspace(1) %invariant.gep3, i64 %53, !dbg !36 + %54 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !37 + %55 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep4, i64 %54, i1 false) #4, !dbg !37 + %gep = getelementptr bfloat, ptr addrspace(1) %invariant.gep5, i64 %53, !dbg !39 + %56 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !40 + %57 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep, i64 %56, i1 false) #4, !dbg !40 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 4, !dbg !34 + %58 = icmp samesign ult i64 %indvars.iv, 124, !dbg !34 + br i1 %58, label %.split, label %.split7.us, !dbg !34 + +.split7.us: ; preds = %.split, %.split.us + %.us-phi = phi float [ %51, %.split.us ], [ 0.000000e+00, %.split ], !dbg !9 
+ %59 = and i32 %12, 63, !dbg !9 + %60 = or disjoint i32 %11, %59, !dbg !10 + %61 = icmp slt i32 %60, %5, !dbg !11 + %62 = bitcast float %.us-phi to i32, !dbg !44 + %63 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %62, i32 2, i32 31), !dbg !44 + %64 = bitcast i32 %63 to float, !dbg !44 + %65 = fadd float %.us-phi, %64, !dbg !48 + %66 = bitcast float %65 to i32, !dbg !44 + %67 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %66, i32 1, i32 31), !dbg !44 + %68 = bitcast i32 %67 to float, !dbg !44 + %69 = fadd float %65, %68, !dbg !48 + %70 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %13, !dbg !49 + store float %69, ptr addrspace(3) %70, align 4, !dbg !49 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !49 + %71 = shl nuw nsw i32 %59, 2, !dbg !49 + %72 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %71, !dbg !49 + %73 = load i32, ptr addrspace(3) %72, align 4, !dbg !49 + %74 = sext i32 %60 to i64, !dbg !50 + %75 = getelementptr float, ptr addrspace(1) %2, i64 %74, !dbg !50 + %76 = and i32 %12, 192, !dbg !51 + %77 = icmp eq i32 %76, 0, !dbg !51 + %78 = and i1 %77, %61, !dbg !51 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %73, ptr addrspace(1) %75, i1 %78) #4, !dbg !51 + ret void, !dbg !52 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { 
nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_zeros_0", linkageName: "triton_red_fused_zeros_0", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 22, column: 28, scope: !4) +!8 = !DILocation(line: 22, column: 33, scope: !4) +!9 = !DILocation(line: 23, column: 44, scope: !4) +!10 = !DILocation(line: 23, column: 23, scope: !4) +!11 = !DILocation(line: 24, column: 21, scope: !4) +!12 = !DILocation(line: 25, column: 37, scope: !4) +!13 = !DILocation(line: 27, column: 19, scope: !4) +!14 = !DILocation(line: 28, column: 21, scope: !4) +!15 = !DILocation(line: 28, column: 28, scope: !4) +!16 = !DILocation(line: 29, column: 19, scope: !4) +!17 = !DILocation(line: 74, column: 34, scope: !18, inlinedAt: !20) +!18 = distinct !DILexicalBlockFile(scope: !4, file: !19, discriminator: 0) +!19 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!20 = !DILocation(line: 30, column: 51, scope: !4) +!21 = !DILocation(line: 75, column: 25, scope: !18, inlinedAt: !20) +!22 = !DILocation(line: 75, column: 36, 
scope: !18, inlinedAt: !20) +!23 = !DILocation(line: 75, column: 32, scope: !18, inlinedAt: !20) +!24 = !DILocation(line: 75, column: 47, scope: !18, inlinedAt: !20) +!25 = !DILocation(line: 39, column: 65, scope: !4) +!26 = !DILocation(line: 39, column: 69, scope: !4) +!27 = !DILocation(line: 40, column: 73, scope: !4) +!28 = !DILocation(line: 40, column: 99, scope: !4) +!29 = !DILocation(line: 40, column: 90, scope: !4) +!30 = !DILocation(line: 40, scope: !4) +!31 = !DILocation(line: 40, column: 81, scope: !4) +!32 = !DILocation(line: 40, column: 54, scope: !4) +!33 = !DILocation(line: 40, column: 58, scope: !4) +!34 = !DILocation(line: 33, column: 40, scope: !4) +!35 = !DILocation(line: 34, column: 31, scope: !4) +!36 = !DILocation(line: 39, column: 34, scope: !4) +!37 = !DILocation(line: 39, column: 74, scope: !4) +!38 = !DILocation(line: 39, column: 136, scope: !4) +!39 = !DILocation(line: 40, column: 34, scope: !4) +!40 = !DILocation(line: 40, column: 106, scope: !4) +!41 = !DILocation(line: 40, column: 168, scope: !4) +!42 = !DILocation(line: 41, column: 22, scope: !4) +!43 = !DILocation(line: 43, column: 23, scope: !4) +!44 = !DILocation(line: 291, column: 36, scope: !45, inlinedAt: !47) +!45 = distinct !DILexicalBlockFile(scope: !4, file: !46, discriminator: 0) +!46 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!47 = !DILocation(line: 45, column: 25, scope: !4) +!48 = !DILocation(line: 261, column: 15, scope: !45, inlinedAt: !47) +!49 = !DILocation(line: 45, column: 28, scope: !4) +!50 = !DILocation(line: 49, column: 25, scope: !4) +!51 = !DILocation(line: 49, column: 36, scope: !4) +!52 = !DILocation(line: 49, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ptx 
b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..b737ed801dfd2df1b7776a051b16a34772334988 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ptx @@ -0,0 +1,517 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_zeros_0 // -- Begin function triton_red_fused_zeros_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_zeros_0 +.visible .entry triton_red_fused_zeros_0( + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_2, + .param .u64 triton_red_fused_zeros_0_param_3, + .param .u64 triton_red_fused_zeros_0_param_4, + .param .u32 triton_red_fused_zeros_0_param_5, + .param .u32 triton_red_fused_zeros_0_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_7, + .param .u64 .ptr .global .align 1 triton_red_fused_zeros_0_param_8 +) +.reqntid 256 +{ + .reg .pred %p<22>; + .reg .b16 %rs<9>; + .reg .b32 %r<41>; + .reg .b64 %rd<102>; + .loc 1 18 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:18:0 + +// %bb.0: + ld.param.b64 %rd35, [triton_red_fused_zeros_0_param_4]; + ld.param.b64 %rd36, [triton_red_fused_zeros_0_param_3]; +$L__tmp0: + .loc 1 22 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:22:28 + mov.u32 %r10, %ctaid.x; + .loc 1 22 33 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:22:33 + shl.b32 %r1, %r10, 6; + .loc 1 23 44 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:44 + mov.u32 %r2, %tid.x; + bfe.u32 %r4, %r2, 2, 
6; + .loc 1 23 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:23 + or.b32 %r11, %r4, %r1; + .loc 1 27 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:27:19 + cvt.s64.s32 %rd1, %r11; + .loc 1 28 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:28:21 + or.b64 %rd37, %rd1, %rd36; + and.b64 %rd38, %rd37, -4294967296; + setp.ne.b64 %p2, %rd38, 0; + cvt.u32.u64 %r38, %rd1; + @%p2 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + div.s64 %rd94, %rd1, %rd36; + bra.uni $L__BB0_3; +$L__BB0_1: + cvt.u32.u64 %r12, %rd36; + div.u32 %r14, %r38, %r12; + cvt.u64.u32 %rd94, %r14; +$L__BB0_3: + .loc 1 0 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:0:21 + ld.param.b32 %r9, [triton_red_fused_zeros_0_param_5]; + and.b32 %r5, %r2, 3; + .loc 1 27 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:27:19 + mul.lo.s64 %rd6, %rd94, %rd36; + .loc 1 28 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:28:28 + shr.s64 %rd39, %rd94, 63; + shr.u64 %rd40, %rd39, 59; + add.s64 %rd41, %rd94, %rd40; + and.b64 %rd42, %rd41, -32; + sub.s64 %rd8, %rd94, %rd42; + .loc 1 29 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:29:19 + or.b64 %rd43, %rd1, %rd35; + and.b64 %rd44, %rd43, -4294967296; + setp.ne.b64 %p3, %rd44, 0; + @%p3 bra $L__BB0_5; + bra.uni $L__BB0_4; +$L__BB0_5: + div.s64 %rd95, %rd1, %rd35; + bra.uni $L__BB0_6; +$L__BB0_4: + cvt.u32.u64 %r15, %rd35; + div.u32 %r17, %r38, %r15; + cvt.u64.u32 %rd95, %r17; +$L__BB0_6: + .loc 1 0 19 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:0:19 + ld.param.b64 %rd33, [triton_red_fused_zeros_0_param_2]; + ld.param.b64 %rd32, [triton_red_fused_zeros_0_param_1]; + ld.param.b64 %rd31, [triton_red_fused_zeros_0_param_0]; + and.b32 %r3, %r2, 252; + sub.s64 %rd7, %rd1, %rd6; + .loc 1 24 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:24:21 + setp.lt.s32 %p4, %r38, %r9; +$L__tmp1: + .loc 2 75 25 // triton_helpers.py:75:25 @[ 
cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.lt.s32 %p5, %r1, 0; + .loc 2 75 36 // triton_helpers.py:75:36 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.lt.s64 %p6, %rd36, 0; + .loc 2 75 32 // triton_helpers.py:75:32 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + xor.pred %p1, %p5, %p6; +$L__tmp2: + .loc 1 40 73 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:73 + setp.lt.s64 %p7, %rd36, 2; + .loc 1 40 99 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:99 + setp.gt.s64 %p8, %rd36, 1; + .loc 1 40 90 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:90 + selp.b64 %rd45, %rd36, 0, %p8; + .loc 1 40 0 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40 + selp.b64 %rd46, 1, 0, %p7; + .loc 1 40 81 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:81 + add.s64 %rd12, %rd45, %rd46; + shl.b64 %rd13, %rd8, 8; + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + cvt.u64.u32 %rd14, %r5; + @%p4 bra $L__BB0_9; + bra.uni $L__BB0_7; +$L__BB0_9: // %.split.us.preheader +$L__tmp3: + .loc 2 74 34 // triton_helpers.py:74:34 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.ne.b64 %p14, %rd7, 0; +$L__tmp4: + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + and.pred %p15, %p14, %p1; + selp.b64 %rd71, -1, 0, %p15; + add.s64 %rd72, %rd94, %rd71; + mul.lo.s64 %rd73, %rd72, %rd12; + shl.b64 %rd74, %rd73, 8; + add.s32 %r22, %r1, %r4; + mad.wide.s32 %rd75, %r22, 256, %rd74; + shl.b64 %rd76, %rd14, 1; + or.b64 %rd77, %rd75, %rd76; + shl.b64 %rd78, %rd6, 8; + sub.s64 %rd79, %rd77, %rd78; + add.s64 %rd97, %rd32, %rd79; + mul.lo.s64 %rd80, %rd95, %rd36; + shl.b64 %rd81, %rd80, 13; + mad.wide.s32 %rd82, %r22, 8192, %rd81; + add.s64 %rd83, %rd82, %rd13; + add.s64 %rd84, %rd83, %rd76; + shl.b64 %rd85, %rd6, 13; + sub.s64 %rd86, %rd84, %rd85; + add.s64 %rd96, %rd31, %rd86; + mov.b32 
%r40, 0f00000000; + mov.b64 %rd98, -4; +$L__BB0_10: // %.split.us + // =>This Inner Loop Header: Depth=1 + .loc 1 39 74 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:39:74 + // begin inline asm + mov.u64 %rd87, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd87, 1.0; + // end inline asm + mov.b16 %rs6, 0; + mov.pred %p16, -1; + // begin inline asm + mov.u16 %rs5, %rs6; + @%p16 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs5 }, [ %rd96 + 0 ], %rd87; + // end inline asm + .loc 1 39 136 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:39:136 + cvt.f32.bf16 %r23, %rs5; + .loc 1 40 106 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:106 + // begin inline asm + mov.u64 %rd90, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd90, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs7, %rs6; + @%p16 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs7 }, [ %rd97 + 0 ], %rd90; + // end inline asm + .loc 1 40 168 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:168 + cvt.f32.bf16 %r24, %rs7; + .loc 1 43 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:43:23 + fma.rn.f32 %r40, %r23, %r24, %r40; + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + add.s64 %rd98, %rd98, 4; + add.s64 %rd97, %rd97, 8; + add.s64 %rd96, %rd96, 8; + setp.lt.u64 %p18, %rd98, 124; + @%p18 bra $L__BB0_10; + bra.uni $L__BB0_11; +$L__BB0_7: // %.split.preheader +$L__tmp5: + .loc 2 74 34 // triton_helpers.py:74:34 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:30:51 ] + setp.ne.b64 %p9, %rd7, 0; +$L__tmp6: + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + and.pred %p10, %p9, %p1; + selp.b64 %rd48, -1, 0, %p10; + add.s64 %rd49, %rd94, %rd48; + mul.lo.s64 %rd50, %rd49, %rd12; + shl.b64 %rd51, %rd50, 8; + add.s32 %r19, %r1, %r4; + mad.wide.s32 %rd52, %r19, 256, %rd51; + shl.b64 %rd53, %rd14, 1; + or.b64 %rd54, %rd52, %rd53; + shl.b64 %rd55, %rd6, 8; + 
sub.s64 %rd56, %rd54, %rd55; + add.s64 %rd100, %rd32, %rd56; + mul.lo.s64 %rd57, %rd95, %rd36; + shl.b64 %rd58, %rd57, 13; + mad.wide.s32 %rd59, %r19, 8192, %rd58; + add.s64 %rd60, %rd59, %rd13; + add.s64 %rd61, %rd60, %rd53; + shl.b64 %rd62, %rd6, 13; + sub.s64 %rd63, %rd61, %rd62; + add.s64 %rd99, %rd31, %rd63; + mov.b64 %rd101, -4; +$L__BB0_8: // %.split + // =>This Inner Loop Header: Depth=1 + .loc 1 39 74 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:39:74 + // begin inline asm + mov.u64 %rd64, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd64, 1.0; + // end inline asm + mov.b16 %rs2, 0; + mov.pred %p11, 0; + // begin inline asm + mov.u16 %rs1, %rs2; + @%p11 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs1 }, [ %rd99 + 0 ], %rd64; + // end inline asm + .loc 1 40 106 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:40:106 + // begin inline asm + mov.u64 %rd67, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd67, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs3, %rs2; + @%p11 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs3 }, [ %rd100 + 0 ], %rd67; + // end inline asm + .loc 1 33 40 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:33:40 + add.s64 %rd101, %rd101, 4; + add.s64 %rd100, %rd100, 8; + add.s64 %rd99, %rd99, 8; + setp.lt.u64 %p13, %rd101, 124; + mov.b32 %r40, 0f00000000; + @%p13 bra $L__BB0_8; +$L__BB0_11: // %.split7.us + .loc 1 23 44 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:44 + and.b32 %r26, %r2, 63; + .loc 1 23 23 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:23:23 + or.b32 %r27, %r1, %r26; + .loc 1 24 21 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:24:21 + setp.lt.s32 %p20, %r27, %r9; +$L__tmp7: + .loc 3 291 36 // standard.py:291:36 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + shfl.sync.bfly.b32 %r28, %r40, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + add.f32 %r29, %r40, %r28; + .loc 3 291 36 // standard.py:291:36 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + shfl.sync.bfly.b32 %r30, %r29, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:25 ] + add.f32 %r31, %r29, %r30; +$L__tmp8: + .loc 1 45 28 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:45:28 + mov.b32 %r32, global_smem; + add.s32 %r33, %r32, %r3; + st.shared.b32 [%r33], %r31; + bar.sync 0; + shl.b32 %r34, %r26, 2; + add.s32 %r35, %r32, %r34; + ld.shared.b32 %r25, [%r35]; + .loc 1 49 25 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:25 + mad.wide.s32 %rd93, %r27, 4, %rd33; + .loc 1 49 36 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:36 + and.b32 %r36, %r2, 192; + setp.eq.b32 %p21, %r36, 0; + and.pred %p19, %p21, %p20; + // begin inline asm + @%p19 st.global.b32 [ %rd93 + 0 ], { %r25 }; + // end inline asm + .loc 1 49 4 // cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py:49:4 + ret; +$L__tmp9: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .file 3 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string 
+.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 233 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xe2 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 120 +.b8 99 +.b8 52 +.b8 105 +.b8 110 +.b8 99 +.b8 119 +.b8 109 +.b8 107 +.b8 108 +.b8 50 +.b8 120 +.b8 111 +.b8 104 +.b8 51 +.b8 116 +.b8 54 +.b8 113 +.b8 104 +.b8 114 +.b8 105 +.b8 114 +.b8 104 +.b8 52 +.b8 104 +.b8 117 +.b8 98 +.b8 105 +.b8 118 +.b8 51 +.b8 113 +.b8 54 +.b8 52 +.b8 53 +.b8 118 +.b8 100 +.b8 113 +.b8 53 +.b8 115 +.b8 119 +.b8 98 +.b8 53 +.b8 103 +.b8 50 +.b8 100 +.b8 120 +.b8 101 +.b8 103 +.b8 122 +.b8 55 +.b8 98 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 
+.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 120 +.b8 99 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1b DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa6:0x46 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbb:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp6 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 30 // DW_AT_call_line +.b8 51 // DW_AT_call_column +.b8 4 // Abbrev [4] 0xd3:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp7 // DW_AT_low_pc +.b64 $L__tmp8 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.source new file mode 100644 index 0000000000000000000000000000000000000000..4f580c765aab3da16073237cfc6ffc84873b6e59 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.source @@ -0,0 +1,325 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc56 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":69:0) +#loc68 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc70 = loc(unknown) +#loc73 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc77 = loc("in_ptr0"(#loc)) +#loc78 = loc("in_ptr1"(#loc)) +#loc79 = loc("out_ptr1"(#loc)) +#loc80 = loc("ks0"(#loc)) +#loc81 = loc("ks1"(#loc)) +#loc82 = loc("xnumel"(#loc)) +#loc83 = loc("r0_numel"(#loc)) +#loc135 = loc("a"(#loc56)) +#loc136 = loc("b"(#loc56)) +#loc142 = loc("input"(#loc68)) +#loc143 = loc("a"(#loc73)) +#loc144 = loc("b"(#loc73)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %r0_numel_0 = arith.constant 128 : i32 loc(#loc84) + %xoffset = tt.get_program_id x : i32 loc(#loc85) + %xoffset_1 = arith.constant 64 : i32 loc(#loc86) + %xoffset_2 = arith.constant 64 : i32 loc(#loc86) + %xoffset_3 = arith.muli %xoffset, %xoffset_2 : i32 loc(#loc86) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc87) + %xindex_4 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc88) + %xindex_5 = tt.splat %xoffset_3 : i32 -> tensor<64x1xi32> loc(#loc89) + %xindex_6 = arith.addi %xindex_5, %xindex_4 : tensor<64x1xi32> loc(#loc89) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32> loc(#loc90) + %xmask_7 = arith.cmpi slt, %xindex_6, %xmask : tensor<64x1xi32> loc(#loc90) + %r0_base = tt.make_range {end = 4 : i32, start = 
0 : i32} : tensor<4xi32> loc(#loc91) + %r0_base_8 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc92) + %x0 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc93) + %x0_9 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc93) + %x0_10 = arith.remsi %x0, %x0_9 : tensor<64x1xi64> loc(#loc93) + %x1 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc94) + %x1_11 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc94) + %x1_12 = arith.divsi %x1, %x1_11 : tensor<64x1xi64> loc(#loc94) + %x1_13 = arith.constant 32 : i32 loc(#loc95) + %x1_14 = arith.constant 32 : i64 loc(#loc95) + %x1_15 = arith.constant dense<32> : tensor<64x1xi64> loc(#loc95) + %x1_16 = arith.remsi %x1_12, %x1_15 : tensor<64x1xi64> loc(#loc95) + %x2 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc96) + %x2_17 = tt.splat %ks1 : i64 -> tensor<64x1xi64> loc(#loc96) + %x2_18 = arith.divsi %x2, %x2_17 : tensor<64x1xi64> loc(#loc96) + %x5 = tt.call @torch._inductor.runtime.triton_helpers.div_floor_integer__i32S64_1S_i64__(%xindex_6, %ks0) : (tensor<64x1xi32>, i64) -> tensor<64x1xi64> loc(#loc97) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc98) + %_tmp4_19 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc98) + %c0_i32 = arith.constant 0 : i32 loc(#loc16) + %c4_i32 = arith.constant 4 : i32 loc(#loc16) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc16) + %1 = arith.bitcast %r0_numel_0 : i32 to i32 loc(#loc16) + %2 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc16) + %3 = ub.poison : i32 loc(#loc16) + %_tmp4_20 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp4_23 = %_tmp4_19) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc100) + %r0_index_24 = arith.addi %r0_index, %r0_base_8 : tensor<1x4xi32> loc(#loc100) + %r0_mask = arith.constant dense<128> : tensor<1x4xi32> loc(#loc101) + %r0_mask_25 = arith.cmpi slt, %r0_index_24, %r0_mask : 
tensor<1x4xi32> loc(#loc101) + %tmp0 = arith.constant 128 : i32 loc(#loc102) + %tmp0_26 = arith.constant 128 : i64 loc(#loc102) + %tmp0_27 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc102) + %tmp0_28 = arith.muli %tmp0_27, %x1_16 : tensor<64x1xi64> loc(#loc102) + %tmp0_29 = arith.extsi %r0_index_24 : tensor<1x4xi32> to tensor<1x4xi64> loc(#loc103) + %tmp0_30 = tt.broadcast %tmp0_29 : tensor<1x4xi64> -> tensor<64x4xi64> loc(#loc103) + %tmp0_31 = tt.broadcast %tmp0_28 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc103) + %tmp0_32 = arith.addi %tmp0_30, %tmp0_31 : tensor<64x4xi64> loc(#loc103) + %tmp0_33 = arith.constant 4096 : i32 loc(#loc104) + %tmp0_34 = arith.constant 4096 : i64 loc(#loc104) + %tmp0_35 = arith.constant dense<4096> : tensor<64x1xi64> loc(#loc104) + %tmp0_36 = arith.muli %tmp0_35, %x0_10 : tensor<64x1xi64> loc(#loc104) + %tmp0_37 = tt.broadcast %tmp0_36 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc105) + %tmp0_38 = arith.addi %tmp0_32, %tmp0_37 : tensor<64x4xi64> loc(#loc105) + %tmp0_39 = arith.constant 4096 : i32 loc(#loc106) + %tmp0_40 = arith.constant 4096 : i64 loc(#loc106) + %tmp0_41 = arith.muli %tmp0_40, %ks0 : i64 loc(#loc106) + %tmp0_42 = tt.splat %tmp0_41 : i64 -> tensor<64x1xi64> loc(#loc107) + %tmp0_43 = arith.muli %tmp0_42, %x2_18 : tensor<64x1xi64> loc(#loc107) + %tmp0_44 = tt.broadcast %tmp0_43 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc108) + %tmp0_45 = arith.addi %tmp0_38, %tmp0_44 : tensor<64x4xi64> loc(#loc108) + %tmp0_46 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc109) + %tmp0_47 = tt.addptr %tmp0_46, %tmp0_45 : tensor<64x4x!tt.ptr>, tensor<64x4xi64> loc(#loc109) + %tmp0_48 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc110) + %tmp0_49 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc110) + %tmp0_50 = arith.andi %tmp0_48, %tmp0_49 : tensor<64x4xi1> loc(#loc110) + %tmp0_51 = arith.constant 0.000000e+00 : f32 loc(#loc111) + %tmp0_52 = arith.constant 
dense<0.000000e+00> : tensor<64x4xf32> loc(#loc111) + %tmp0_53 = arith.truncf %tmp0_52 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc111) + %tmp0_54 = tt.load %tmp0_47, %tmp0_50, %tmp0_53 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc111) + %tmp0_55 = arith.extf %tmp0_54 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc112) + %tmp1 = arith.constant 128 : i32 loc(#loc113) + %tmp1_56 = arith.constant 128 : i64 loc(#loc113) + %tmp1_57 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc113) + %tmp1_58 = arith.muli %tmp1_57, %x0_10 : tensor<64x1xi64> loc(#loc113) + %tmp1_59 = arith.extsi %r0_index_24 : tensor<1x4xi32> to tensor<1x4xi64> loc(#loc114) + %tmp1_60 = tt.broadcast %tmp1_59 : tensor<1x4xi64> -> tensor<64x4xi64> loc(#loc114) + %tmp1_61 = tt.broadcast %tmp1_58 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc114) + %tmp1_62 = arith.addi %tmp1_60, %tmp1_61 : tensor<64x4xi64> loc(#loc114) + %tmp1_63 = arith.constant 128 : i32 loc(#loc115) + %tmp1_64 = arith.constant 128 : i64 loc(#loc115) + %tmp1_65 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc115) + %tmp1_66 = arith.muli %tmp1_65, %x5 : tensor<64x1xi64> loc(#loc115) + %tmp1_67 = arith.constant 1 : i32 loc(#loc116) + %tmp1_68 = arith.extsi %tmp1_67 : i32 to i64 loc(#loc116) + %tmp1_69 = arith.cmpi sge, %tmp1_68, %ks0 : i64 loc(#loc116) + %tmp1_70 = arith.constant 1 : i32 loc(#loc117) + %tmp1_71 = arith.constant 1 : i32 loc(#loc117) + %tmp1_72 = arith.extui %tmp1_69 : i1 to i32 loc(#loc117) + %tmp1_73 = arith.muli %tmp1_71, %tmp1_72 : i32 loc(#loc117) + %tmp1_74 = arith.constant 1 : i32 loc(#loc118) + %tmp1_75 = arith.extsi %tmp1_74 : i32 to i64 loc(#loc118) + %tmp1_76 = arith.cmpi sgt, %ks0, %tmp1_75 : i64 loc(#loc118) + %tmp1_77 = arith.extui %tmp1_76 : i1 to i64 loc(#loc119) + %tmp1_78 = arith.muli %ks0, %tmp1_77 : i64 loc(#loc119) + %tmp1_79 = arith.extsi %tmp1_73 : i32 to i64 loc(#loc120) + %tmp1_80 = arith.addi %tmp1_79, %tmp1_78 : i64 loc(#loc120) + %tmp1_81 = tt.splat %tmp1_80 : 
i64 -> tensor<64x1xi64> loc(#loc121) + %tmp1_82 = arith.muli %tmp1_66, %tmp1_81 : tensor<64x1xi64> loc(#loc121) + %tmp1_83 = tt.broadcast %tmp1_82 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc122) + %tmp1_84 = arith.addi %tmp1_62, %tmp1_83 : tensor<64x4xi64> loc(#loc122) + %tmp1_85 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc123) + %tmp1_86 = tt.addptr %tmp1_85, %tmp1_84 : tensor<64x4x!tt.ptr>, tensor<64x4xi64> loc(#loc123) + %tmp1_87 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc124) + %tmp1_88 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc124) + %tmp1_89 = arith.andi %tmp1_87, %tmp1_88 : tensor<64x4xi1> loc(#loc124) + %tmp1_90 = arith.constant 0.000000e+00 : f32 loc(#loc125) + %tmp1_91 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc125) + %tmp1_92 = arith.truncf %tmp1_91 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc125) + %tmp1_93 = tt.load %tmp1_86, %tmp1_89, %tmp1_92 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc125) + %tmp1_94 = arith.extf %tmp1_93 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc126) + %tmp2 = arith.mulf %tmp0_55, %tmp1_94 : tensor<64x4xf32> loc(#loc127) + %tmp5 = arith.addf %_tmp4_23, %tmp2 : tensor<64x4xf32> loc(#loc128) + %_tmp4_95 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc129) + %_tmp4_96 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc129) + %_tmp4_97 = arith.andi %_tmp4_95, %_tmp4_96 : tensor<64x4xi1> loc(#loc129) + %_tmp4_98 = arith.select %_tmp4_97, %tmp5, %_tmp4_23 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc130) + scf.yield %_tmp4_98 : tensor<64x4xf32> loc(#loc48) + } loc(#loc99) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_20) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc131) + %tmp4_21 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc132) + %tmp7 = arith.constant 
0.000000e+00 : f32 loc(#loc133) + %tmp8 = arith.constant dense<0.000000e+00> : tensor<64x1xf32> loc(#loc134) + %tmp8_22 = arith.subf %tmp4_21, %tmp8 : tensor<64x1xf32> loc(#loc134) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc53) + %5 = tt.addptr %4, %xindex_6 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc53) + tt.store %5, %tmp8_22, %xmask_7 : tensor<64x1x!tt.ptr> loc(#loc54) + tt.return loc(#loc55) + } loc(#loc) + tt.func private @torch._inductor.runtime.triton_helpers.div_floor_integer__i32S64_1S_i64__(%a: tensor<64x1xi32> loc("a"(#loc56)), %b: i64 loc("b"(#loc56))) -> tensor<64x1xi64> attributes {noinline = false} { + %quot = arith.extsi %a : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc137) + %quot_0 = tt.splat %b : i64 -> tensor<64x1xi64> loc(#loc137) + %quot_1 = arith.divsi %quot, %quot_0 : tensor<64x1xi64> loc(#loc137) + %remainder = arith.extsi %a : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc138) + %remainder_2 = tt.splat %b : i64 -> tensor<64x1xi64> loc(#loc138) + %remainder_3 = arith.remsi %remainder, %remainder_2 : tensor<64x1xi64> loc(#loc138) + %fixed = arith.constant 0 : i32 loc(#loc139) + %fixed_4 = arith.extsi %fixed : i32 to i64 loc(#loc139) + %fixed_5 = tt.splat %fixed_4 : i64 -> tensor<64x1xi64> loc(#loc139) + %fixed_6 = arith.cmpi ne, %remainder_3, %fixed_5 : tensor<64x1xi64> loc(#loc139) + %fixed_7 = arith.constant 1 : i32 loc(#loc140) + %fixed_8 = arith.constant 1 : i64 loc(#loc140) + %fixed_9 = arith.constant dense<1> : tensor<64x1xi64> loc(#loc140) + %fixed_10 = arith.subi %quot_1, %fixed_9 : tensor<64x1xi64> loc(#loc140) + %fixed_11 = arith.select %fixed_6, %fixed_10, %quot_1 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc141) + %c0_i32 = arith.constant 0 : i32 loc(#loc62) + %cst = arith.constant dense<0> : tensor<64x1xi32> loc(#loc62) + %0 = arith.cmpi slt, %a, %cst : tensor<64x1xi32> loc(#loc62) + %c0_i32_12 = arith.constant 0 : i32 loc(#loc63) + %1 = arith.extsi %c0_i32_12 : i32 to i64 loc(#loc63) + %2 = 
arith.cmpi slt, %b, %1 : i64 loc(#loc63) + %3 = tt.splat %2 : i1 -> tensor<64x1xi1> loc(#loc64) + %4 = arith.cmpi ne, %0, %3 : tensor<64x1xi1> loc(#loc64) + %5 = arith.select %4, %fixed_11, %quot_1 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc65) + tt.return %5 : tensor<64x1xi64> loc(#loc66) + ^bb1: // no predecessors + %6 = ub.poison : tensor<64x1xi64> loc(#loc67) + tt.return %6 : tensor<64x1xi64> loc(#loc67) + } loc(#loc56) + tt.func private @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x4xf32> loc("input"(#loc68))) -> tensor<64xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc69) + tt.reduce.return %2 : f32 loc(#loc69) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc69) + tt.return %0 : tensor<64xf32> loc(#loc71) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64xf32> loc(#loc72) + tt.return %1 : tensor<64xf32> loc(#loc72) + } loc(#loc68) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc73)), %b: f32 loc("b"(#loc73))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc74) + tt.return %0 : f32 loc(#loc75) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc76) + tt.return %1 : f32 loc(#loc76) + } loc(#loc73) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":19:15) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc4 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:36) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:27) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":31:43) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc18 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc32 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:116) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc46 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:35) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":47:11) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":48:18) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":72:16) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":73:20) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc62 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc65 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:11) +#loc67 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:4) +#loc69 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc71 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc72 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc74 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc75 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc76 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc84 = loc("r0_numel"(#loc1)) +#loc85 = loc("xoffset"(#loc2)) +#loc86 = loc("xoffset"(#loc3)) +#loc87 = loc("xindex"(#loc4)) +#loc88 = loc("xindex"(#loc5)) +#loc89 = loc("xindex"(#loc6)) +#loc90 = loc("xmask"(#loc7)) +#loc91 = loc("r0_base"(#loc8)) +#loc92 = loc("r0_base"(#loc9)) +#loc93 = loc("x0"(#loc10)) +#loc94 = loc("x1"(#loc11)) +#loc95 = loc("x1"(#loc12)) +#loc96 = loc("x2"(#loc13)) +#loc97 = loc("x5"(#loc14)) +#loc98 = loc("_tmp4"(#loc15)) +#loc99 = loc("_tmp4"(#loc16)) +#loc100 = loc("r0_index"(#loc17)) +#loc101 = loc("r0_mask"(#loc18)) +#loc102 = loc("tmp0"(#loc19)) +#loc103 = loc("tmp0"(#loc20)) +#loc104 = loc("tmp0"(#loc21)) +#loc105 = loc("tmp0"(#loc22)) +#loc106 = loc("tmp0"(#loc23)) 
+#loc107 = loc("tmp0"(#loc24)) +#loc108 = loc("tmp0"(#loc25)) +#loc109 = loc("tmp0"(#loc26)) +#loc110 = loc("tmp0"(#loc27)) +#loc111 = loc("tmp0"(#loc28)) +#loc112 = loc("tmp0"(#loc29)) +#loc113 = loc("tmp1"(#loc30)) +#loc114 = loc("tmp1"(#loc31)) +#loc115 = loc("tmp1"(#loc32)) +#loc116 = loc("tmp1"(#loc33)) +#loc117 = loc("tmp1"(#loc34)) +#loc118 = loc("tmp1"(#loc35)) +#loc119 = loc("tmp1"(#loc36)) +#loc120 = loc("tmp1"(#loc37)) +#loc121 = loc("tmp1"(#loc38)) +#loc122 = loc("tmp1"(#loc39)) +#loc123 = loc("tmp1"(#loc40)) +#loc124 = loc("tmp1"(#loc41)) +#loc125 = loc("tmp1"(#loc42)) +#loc126 = loc("tmp1"(#loc43)) +#loc127 = loc("tmp2"(#loc44)) +#loc128 = loc("tmp5"(#loc45)) +#loc129 = loc("_tmp4"(#loc46)) +#loc130 = loc("_tmp4"(#loc47)) +#loc131 = loc("tmp4"(#loc49)) +#loc132 = loc("tmp4"(#loc50)) +#loc133 = loc("tmp7"(#loc51)) +#loc134 = loc("tmp8"(#loc52)) +#loc137 = loc("quot"(#loc57)) +#loc138 = loc("remainder"(#loc58)) +#loc139 = loc("fixed"(#loc59)) +#loc140 = loc("fixed"(#loc60)) +#loc141 = loc("fixed"(#loc61)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..b508f7b67ae7f71ec23213897f80f7794a6c8a20 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttgir @@ -0,0 +1,233 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc1 = loc(unknown) +#loc52 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc58 = loc("in_ptr0"(#loc)) +#loc59 = loc("in_ptr1"(#loc)) +#loc60 = loc("out_ptr1"(#loc)) +#loc61 = loc("ks0"(#loc)) +#loc62 = loc("ks1"(#loc)) +#loc63 = loc("xnumel"(#loc)) +#loc64 = loc("r0_numel"(#loc)) +#loc109 = loc("tmp4"(#loc52)) +#loc120 = loc(callsite(#loc1 at #loc109)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<32> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<1> : tensor<64x1xi64, #blocked> loc(#loc1) + %c0_i64 = arith.constant 0 : i64 loc(#loc1) + %c1_i64 = arith.constant 1 : i64 loc(#loc1) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x4xbf16, #blocked> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c4096_i64 = arith.constant 4096 : i64 loc(#loc1) + %cst_6 = arith.constant dense<128> : tensor<1x4xi32, #blocked> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : 
tensor<64x4xf32, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc65) + %xoffset_8 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc66) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc67) + %xindex_9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc67) + %xindex_10 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc67) + %xindex_11 = tt.expand_dims %xindex_9 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> loc(#loc67) + %xindex_12 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked> loc(#loc68) + %xindex_13 = tt.splat %xoffset_8 : i32 -> tensor<64x1xi32, #blocked1> loc(#loc68) + %xindex_14 = arith.addi %xindex_12, %xindex_10 : tensor<64x1xi32, #blocked> loc(#loc68) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<64x1xi32, #blocked1> loc(#loc68) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked> loc(#loc69) + %xmask_16 = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked1> loc(#loc69) + %xmask_17 = arith.cmpi slt, %xindex_14, %xmask : tensor<64x1xi32, #blocked> loc(#loc69) + %xmask_18 = arith.cmpi slt, %xindex_15, %xmask_16 : tensor<64x1xi32, #blocked1> loc(#loc69) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc70) + %r0_base_19 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x4xi32, #blocked> loc(#loc70) + %x0 = arith.extsi %xindex_14 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> loc(#loc71) + %x0_20 = tt.splat %ks0 : i64 -> tensor<64x1xi64, #blocked> loc(#loc71) + %x0_21 = arith.remsi %x0, %x0_20 : tensor<64x1xi64, #blocked> 
loc(#loc71) + %x1 = arith.divsi %x0, %x0_20 : tensor<64x1xi64, #blocked> loc(#loc72) + %x1_22 = arith.remsi %x1, %cst : tensor<64x1xi64, #blocked> loc(#loc73) + %x2 = tt.splat %ks1 : i64 -> tensor<64x1xi64, #blocked> loc(#loc74) + %x2_23 = arith.divsi %x0, %x2 : tensor<64x1xi64, #blocked> loc(#loc74) + %fixed = arith.cmpi ne, %x0_21, %cst_2 : tensor<64x1xi64, #blocked> loc(#loc111) + %fixed_24 = arith.subi %x1, %cst_4 : tensor<64x1xi64, #blocked> loc(#loc112) + %fixed_25 = arith.select %fixed, %fixed_24, %x1 : tensor<64x1xi1, #blocked>, tensor<64x1xi64, #blocked> loc(#loc113) + %x5 = arith.cmpi slt, %xindex_14, %cst_3 : tensor<64x1xi32, #blocked> loc(#loc114) + %x5_26 = arith.cmpi slt, %ks0, %c0_i64 : i64 loc(#loc115) + %x5_27 = tt.splat %x5_26 : i1 -> tensor<64x1xi1, #blocked> loc(#loc116) + %x5_28 = arith.cmpi ne, %x5, %x5_27 : tensor<64x1xi1, #blocked> loc(#loc116) + %x5_29 = arith.select %x5_28, %fixed_25, %x1 : tensor<64x1xi1, #blocked>, tensor<64x1xi64, #blocked> loc(#loc117) + %tmp0 = arith.muli %x1_22, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc79) + %tmp0_30 = tt.broadcast %tmp0 : tensor<64x1xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc80) + %tmp0_31 = arith.muli %x0_21, %cst_1 : tensor<64x1xi64, #blocked> loc(#loc81) + %tmp0_32 = tt.broadcast %tmp0_31 : tensor<64x1xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc82) + %tmp0_33 = arith.muli %ks0, %c4096_i64 : i64 loc(#loc83) + %tmp0_34 = tt.splat %tmp0_33 : i64 -> tensor<64x1xi64, #blocked> loc(#loc84) + %tmp0_35 = arith.muli %tmp0_34, %x2_23 : tensor<64x1xi64, #blocked> loc(#loc84) + %tmp0_36 = tt.broadcast %tmp0_35 : tensor<64x1xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc85) + %tmp0_37 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked> loc(#loc86) + %tmp0_38 = tt.broadcast %xmask_17 : tensor<64x1xi1, #blocked> -> tensor<64x4xi1, #blocked> loc(#loc87) + %tmp1 = arith.muli %x0_21, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc88) + %tmp1_39 = tt.broadcast %tmp1 : 
tensor<64x1xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc89) + %tmp1_40 = arith.muli %x5_29, %cst_0 : tensor<64x1xi64, #blocked> loc(#loc90) + %tmp1_41 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc91) + %tmp1_42 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc92) + %tmp1_43 = arith.extui %tmp1_42 : i1 to i64 loc(#loc93) + %tmp1_44 = arith.muli %ks0, %tmp1_43 : i64 loc(#loc93) + %tmp1_45 = arith.extui %tmp1_41 : i1 to i64 loc(#loc118) + %tmp1_46 = arith.addi %tmp1_45, %tmp1_44 : i64 loc(#loc94) + %tmp1_47 = tt.splat %tmp1_46 : i64 -> tensor<64x1xi64, #blocked> loc(#loc96) + %tmp1_48 = arith.muli %tmp1_40, %tmp1_47 : tensor<64x1xi64, #blocked> loc(#loc96) + %tmp1_49 = tt.broadcast %tmp1_48 : tensor<64x1xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc97) + %tmp1_50 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked> loc(#loc98) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_53 = %cst_7) -> (tensor<64x4xf32, #blocked>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32, #blocked> loc(#loc100) + %r0_index_54 = arith.addi %r0_index, %r0_base_19 : tensor<1x4xi32, #blocked> loc(#loc100) + %r0_mask = arith.cmpi slt, %r0_index_54, %cst_6 : tensor<1x4xi32, #blocked> loc(#loc101) + %tmp0_55 = arith.extsi %r0_index_54 : tensor<1x4xi32, #blocked> to tensor<1x4xi64, #blocked> loc(#loc80) + %tmp0_56 = tt.broadcast %tmp0_55 : tensor<1x4xi64, #blocked> -> tensor<64x4xi64, #blocked> loc(#loc80) + %tmp0_57 = arith.addi %tmp0_56, %tmp0_30 : tensor<64x4xi64, #blocked> loc(#loc80) + %tmp0_58 = arith.addi %tmp0_57, %tmp0_32 : tensor<64x4xi64, #blocked> loc(#loc82) + %tmp0_59 = arith.addi %tmp0_58, %tmp0_36 : tensor<64x4xi64, #blocked> loc(#loc85) + %tmp0_60 = tt.addptr %tmp0_37, %tmp0_59 : tensor<64x4x!tt.ptr, #blocked>, tensor<64x4xi64, #blocked> loc(#loc86) + %tmp0_61 = tt.broadcast %r0_mask : tensor<1x4xi1, #blocked> -> tensor<64x4xi1, #blocked> loc(#loc87) + %tmp0_62 = arith.andi %tmp0_61, %tmp0_38 : 
tensor<64x4xi1, #blocked> loc(#loc87) + %tmp0_63 = tt.load %tmp0_60, %tmp0_62, %cst_5 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked> loc(#loc102) + %tmp0_64 = arith.extf %tmp0_63 : tensor<64x4xbf16, #blocked> to tensor<64x4xf32, #blocked> loc(#loc103) + %tmp1_65 = arith.addi %tmp0_56, %tmp1_39 : tensor<64x4xi64, #blocked> loc(#loc89) + %tmp1_66 = arith.addi %tmp1_65, %tmp1_49 : tensor<64x4xi64, #blocked> loc(#loc97) + %tmp1_67 = tt.addptr %tmp1_50, %tmp1_66 : tensor<64x4x!tt.ptr, #blocked>, tensor<64x4xi64, #blocked> loc(#loc98) + %tmp1_68 = tt.load %tmp1_67, %tmp0_62, %cst_5 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked> loc(#loc104) + %tmp1_69 = arith.extf %tmp1_68 : tensor<64x4xbf16, #blocked> to tensor<64x4xf32, #blocked> loc(#loc105) + %tmp2 = arith.mulf %tmp0_64, %tmp1_69 : tensor<64x4xf32, #blocked> loc(#loc106) + %tmp5 = arith.addf %_tmp4_53, %tmp2 : tensor<64x4xf32, #blocked> loc(#loc107) + %_tmp4_70 = arith.select %tmp0_62, %tmp5, %_tmp4_53 : tensor<64x4xi1, #blocked>, tensor<64x4xf32, #blocked> loc(#loc108) + scf.yield %_tmp4_70 : tensor<64x4xf32, #blocked> loc(#loc50) + } loc(#loc99) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_53: f32 loc(callsite(#loc1 at #loc109)), %tmp4_54: f32 loc(callsite(#loc1 at #loc109))): + %tmp4_55 = arith.addf %tmp4_53, %tmp4_54 : f32 loc(#loc121) + tt.reduce.return %tmp4_55 : f32 loc(#loc119) + }) : (tensor<64x4xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc119) + %tmp4_51 = ttg.convert_layout %tmp4 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc110) + %tmp4_52 = tt.expand_dims %tmp4_51 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xf32, #blocked1> loc(#loc110) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> loc(#loc55) + %1 = tt.addptr %0, %xindex_15 : tensor<64x1x!tt.ptr, #blocked1>, 
tensor<64x1xi32, #blocked1> loc(#loc55) + tt.store %1, %tmp4_52, %xmask_18 : tensor<64x1x!tt.ptr, #blocked1> loc(#loc56) + tt.return loc(#loc57) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) +#loc12 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc14 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc15 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc16 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc17 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc18 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc30 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc44 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc65 = loc("xoffset"(#loc2)) +#loc66 = loc("xoffset"(#loc3)) +#loc67 = loc("xindex"(#loc4)) +#loc68 = loc("xindex"(#loc5)) +#loc69 = loc("xmask"(#loc6)) +#loc70 = loc("r0_base"(#loc7)) +#loc71 = loc("x0"(#loc8)) +#loc72 = loc("x1"(#loc9)) +#loc73 = loc("x1"(#loc10)) +#loc74 = loc("x2"(#loc11)) 
+#loc75 = loc("fixed"(#loc12)) +#loc76 = loc("x5"(#loc13)) +#loc77 = loc("fixed"(#loc14)) +#loc78 = loc("fixed"(#loc15)) +#loc79 = loc("tmp0"(#loc20)) +#loc80 = loc("tmp0"(#loc21)) +#loc81 = loc("tmp0"(#loc22)) +#loc82 = loc("tmp0"(#loc23)) +#loc83 = loc("tmp0"(#loc24)) +#loc84 = loc("tmp0"(#loc25)) +#loc85 = loc("tmp0"(#loc26)) +#loc86 = loc("tmp0"(#loc27)) +#loc87 = loc("tmp0"(#loc28)) +#loc88 = loc("tmp1"(#loc29)) +#loc89 = loc("tmp1"(#loc30)) +#loc90 = loc("tmp1"(#loc31)) +#loc91 = loc("tmp1"(#loc32)) +#loc92 = loc("tmp1"(#loc33)) +#loc93 = loc("tmp1"(#loc34)) +#loc94 = loc("tmp1"(#loc35)) +#loc95 = loc("tmp1"(#loc36)) +#loc96 = loc("tmp1"(#loc37)) +#loc97 = loc("tmp1"(#loc38)) +#loc98 = loc("tmp1"(#loc39)) +#loc99 = loc("_tmp4"(#loc40)) +#loc100 = loc("r0_index"(#loc41)) +#loc101 = loc("r0_mask"(#loc42)) +#loc102 = loc("tmp0"(#loc43)) +#loc103 = loc("tmp0"(#loc44)) +#loc104 = loc("tmp1"(#loc45)) +#loc105 = loc("tmp1"(#loc46)) +#loc106 = loc("tmp2"(#loc47)) +#loc107 = loc("tmp5"(#loc48)) +#loc108 = loc("_tmp4"(#loc49)) +#loc110 = loc("tmp4"(#loc54)) +#loc111 = loc(callsite(#loc75 at #loc76)) +#loc112 = loc(callsite(#loc77 at #loc76)) +#loc113 = loc(callsite(#loc78 at #loc76)) +#loc114 = loc(callsite(#loc16 at #loc76)) +#loc115 = loc(callsite(#loc17 at #loc76)) +#loc116 = loc(callsite(#loc18 at #loc76)) +#loc117 = loc(callsite(#loc19 at #loc76)) +#loc118 = loc(fused[#loc94, #loc95]) +#loc119 = loc(callsite(#loc51 at #loc109)) +#loc121 = loc(callsite(#loc53 at #loc119)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..1780babbd588d70a514b373357742a93a12ff6af --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/FM6CWX5NDNUSSOQWH3EGE2NL4IM53NMNXM4NI2FUVI2663Y6EMSQ/triton_red_fused_zeros_0.ttir @@ -0,0 +1,228 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":18:0) +#loc6 = loc(unknown) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:25) +#loc60 = loc("in_ptr0"(#loc)) +#loc61 = loc("in_ptr1"(#loc)) +#loc62 = loc("out_ptr1"(#loc)) +#loc63 = loc("ks0"(#loc)) +#loc64 = loc("ks1"(#loc)) +#loc65 = loc("xnumel"(#loc)) +#loc66 = loc("r0_numel"(#loc)) +#loc113 = loc("tmp4"(#loc54)) +#loc124 = loc(callsite(#loc6 at #loc113)) +module { + tt.func public @triton_red_fused_zeros_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 {tt.divisibility = 16 : i32} loc("ks1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %fixed = arith.constant dense<1> : tensor<64x1xi64> loc(#loc115) + %x5 = arith.constant dense<0> : tensor<64x1xi32> loc(#loc116) + %fixed_0 = arith.constant dense<0> : tensor<64x1xi64> loc(#loc117) + %x5_1 = arith.constant 0 : i64 loc(#loc118) + %c1_i64 = arith.constant 1 : i64 loc(#loc6) + %cst = arith.constant dense<0.000000e+00> : tensor<64x4xbf16> loc(#loc6) + %c4_i32 = arith.constant 4 : i32 loc(#loc7) + %c128_i32 = arith.constant 128 : i32 loc(#loc7) + %c0_i32 = arith.constant 0 : i32 loc(#loc7) + %cst_2 = arith.constant dense<4096> : tensor<64x1xi64> loc(#loc6) + %c4096_i64 = arith.constant 4096 : i64 loc(#loc6) + %cst_3 = arith.constant dense<128> : tensor<64x1xi64> loc(#loc6) + %cst_4 = arith.constant dense<128> : tensor<1x4xi32> loc(#loc6) + 
%cst_5 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc6) + %x1 = arith.constant dense<32> : tensor<64x1xi64> loc(#loc70) + %c64_i32 = arith.constant 64 : i32 loc(#loc6) + %xoffset = tt.get_program_id x : i32 loc(#loc71) + %xoffset_6 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc72) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc73) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc74) + %xindex_8 = tt.splat %xoffset_6 : i32 -> tensor<64x1xi32> loc(#loc75) + %xindex_9 = arith.addi %xindex_8, %xindex_7 : tensor<64x1xi32> loc(#loc75) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32> loc(#loc76) + %xmask_10 = arith.cmpi slt, %xindex_9, %xmask : tensor<64x1xi32> loc(#loc76) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc77) + %r0_base_11 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc78) + %x0 = arith.extsi %xindex_9 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc79) + %x0_12 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc79) + %x0_13 = arith.remsi %x0, %x0_12 : tensor<64x1xi64> loc(#loc79) + %x1_14 = arith.divsi %x0, %x0_12 : tensor<64x1xi64> loc(#loc80) + %x1_15 = arith.remsi %x1_14, %x1 : tensor<64x1xi64> loc(#loc70) + %x2 = tt.splat %ks1 : i64 -> tensor<64x1xi64> loc(#loc81) + %x2_16 = arith.divsi %x0, %x2 : tensor<64x1xi64> loc(#loc81) + %fixed_17 = arith.cmpi ne, %x0_13, %fixed_0 : tensor<64x1xi64> loc(#loc117) + %fixed_18 = arith.subi %x1_14, %fixed : tensor<64x1xi64> loc(#loc115) + %fixed_19 = arith.select %fixed_17, %fixed_18, %x1_14 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc119) + %x5_20 = arith.cmpi slt, %xindex_9, %x5 : tensor<64x1xi32> loc(#loc116) + %x5_21 = arith.cmpi slt, %ks0, %x5_1 : i64 loc(#loc118) + %x5_22 = tt.splat %x5_21 : i1 -> tensor<64x1xi1> loc(#loc120) + %x5_23 = arith.cmpi ne, %x5_20, %x5_22 : tensor<64x1xi1> loc(#loc120) + %x5_24 = arith.select %x5_23, 
%fixed_19, %x1_14 : tensor<64x1xi1>, tensor<64x1xi64> loc(#loc121) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_26 = %cst_5) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc84) + %r0_index_27 = arith.addi %r0_index, %r0_base_11 : tensor<1x4xi32> loc(#loc84) + %r0_mask = arith.cmpi slt, %r0_index_27, %cst_4 : tensor<1x4xi32> loc(#loc85) + %tmp0 = arith.muli %x1_15, %cst_3 : tensor<64x1xi64> loc(#loc86) + %tmp0_28 = arith.extsi %r0_index_27 : tensor<1x4xi32> to tensor<1x4xi64> loc(#loc87) + %tmp0_29 = tt.broadcast %tmp0_28 : tensor<1x4xi64> -> tensor<64x4xi64> loc(#loc87) + %tmp0_30 = tt.broadcast %tmp0 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc87) + %tmp0_31 = arith.addi %tmp0_29, %tmp0_30 : tensor<64x4xi64> loc(#loc87) + %tmp0_32 = arith.muli %x0_13, %cst_2 : tensor<64x1xi64> loc(#loc88) + %tmp0_33 = tt.broadcast %tmp0_32 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc89) + %tmp0_34 = arith.addi %tmp0_31, %tmp0_33 : tensor<64x4xi64> loc(#loc89) + %tmp0_35 = arith.muli %ks0, %c4096_i64 : i64 loc(#loc90) + %tmp0_36 = tt.splat %tmp0_35 : i64 -> tensor<64x1xi64> loc(#loc91) + %tmp0_37 = arith.muli %tmp0_36, %x2_16 : tensor<64x1xi64> loc(#loc91) + %tmp0_38 = tt.broadcast %tmp0_37 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc92) + %tmp0_39 = arith.addi %tmp0_34, %tmp0_38 : tensor<64x4xi64> loc(#loc92) + %tmp0_40 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc93) + %tmp0_41 = tt.addptr %tmp0_40, %tmp0_39 : tensor<64x4x!tt.ptr>, tensor<64x4xi64> loc(#loc93) + %tmp0_42 = tt.broadcast %r0_mask : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc94) + %tmp0_43 = tt.broadcast %xmask_10 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc94) + %tmp0_44 = arith.andi %tmp0_42, %tmp0_43 : tensor<64x4xi1> loc(#loc94) + %tmp0_45 = tt.load %tmp0_41, %tmp0_44, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc95) + %tmp0_46 = arith.extf %tmp0_45 : tensor<64x4xbf16> to 
tensor<64x4xf32> loc(#loc96) + %tmp1 = arith.muli %x0_13, %cst_3 : tensor<64x1xi64> loc(#loc97) + %tmp1_47 = tt.broadcast %tmp1 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc98) + %tmp1_48 = arith.addi %tmp0_29, %tmp1_47 : tensor<64x4xi64> loc(#loc98) + %tmp1_49 = arith.muli %x5_24, %cst_3 : tensor<64x1xi64> loc(#loc99) + %tmp1_50 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc100) + %tmp1_51 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc101) + %tmp1_52 = arith.extui %tmp1_51 : i1 to i64 loc(#loc102) + %tmp1_53 = arith.muli %ks0, %tmp1_52 : i64 loc(#loc102) + %tmp1_54 = arith.extui %tmp1_50 : i1 to i64 loc(#loc122) + %tmp1_55 = arith.addi %tmp1_54, %tmp1_53 : i64 loc(#loc103) + %tmp1_56 = tt.splat %tmp1_55 : i64 -> tensor<64x1xi64> loc(#loc105) + %tmp1_57 = arith.muli %tmp1_49, %tmp1_56 : tensor<64x1xi64> loc(#loc105) + %tmp1_58 = tt.broadcast %tmp1_57 : tensor<64x1xi64> -> tensor<64x4xi64> loc(#loc106) + %tmp1_59 = arith.addi %tmp1_48, %tmp1_58 : tensor<64x4xi64> loc(#loc106) + %tmp1_60 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc107) + %tmp1_61 = tt.addptr %tmp1_60, %tmp1_59 : tensor<64x4x!tt.ptr>, tensor<64x4xi64> loc(#loc107) + %tmp1_62 = tt.load %tmp1_61, %tmp0_44, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc108) + %tmp1_63 = arith.extf %tmp1_62 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc109) + %tmp2 = arith.mulf %tmp0_46, %tmp1_63 : tensor<64x4xf32> loc(#loc110) + %tmp5 = arith.addf %_tmp4_26, %tmp2 : tensor<64x4xf32> loc(#loc111) + %_tmp4_64 = arith.select %tmp0_44, %tmp5, %_tmp4_26 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc112) + scf.yield %_tmp4_64 : tensor<64x4xf32> loc(#loc52) + } loc(#loc83) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_26: f32 loc(callsite(#loc6 at #loc113)), %tmp4_27: f32 loc(callsite(#loc6 at #loc113))): + %tmp4_28 = arith.addf %tmp4_26, %tmp4_27 : f32 loc(#loc125) + tt.reduce.return %tmp4_28 : f32 loc(#loc123) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc123) + 
%tmp4_25 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc114) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc57) + %1 = tt.addptr %0, %xindex_9 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc57) + tt.store %1, %tmp4_25, %xmask_10 : tensor<64x1x!tt.ptr> loc(#loc58) + tt.return loc(#loc59) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":30:51) +#loc3 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc4 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc5 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":33:40) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:28) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:28) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":22:33) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:36) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:44) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":23:23) +#loc14 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":24:21) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:27) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":25:37) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":27:19) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":28:21) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":29:19) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":34:31) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":35:29) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:45) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:41) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:55) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:50) +#loc29 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:65) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:69) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:60) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:34) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:84) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:74) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":39:136) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:45) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:41) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:54) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:73) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:99) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:90) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:81) +#loc43 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:65) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:58) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:50) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:34) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:106) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":40:168) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":41:22) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":43:23) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:48) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":44:8) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":45:28) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:25) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:36) +#loc59 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xc/cxc4incwmkl2xoh3t6qhrirh4hubiv3q645vdq5swb5g2dxegz7b.py":49:4) +#loc67 = loc("fixed"(#loc1)) +#loc68 = loc("x5"(#loc2)) +#loc69 = loc("fixed"(#loc4)) +#loc70 = loc("x1"(#loc8)) +#loc71 = loc("xoffset"(#loc9)) +#loc72 = loc("xoffset"(#loc10)) +#loc73 = loc("xindex"(#loc11)) +#loc74 = loc("xindex"(#loc12)) +#loc75 = loc("xindex"(#loc13)) +#loc76 = loc("xmask"(#loc14)) +#loc77 = loc("r0_base"(#loc15)) +#loc78 = loc("r0_base"(#loc16)) +#loc79 = loc("x0"(#loc17)) +#loc80 = loc("x1"(#loc18)) +#loc81 = loc("x2"(#loc19)) +#loc82 = loc("fixed"(#loc20)) +#loc83 = loc("_tmp4"(#loc7)) +#loc84 = loc("r0_index"(#loc23)) +#loc85 = loc("r0_mask"(#loc24)) +#loc86 = loc("tmp0"(#loc25)) +#loc87 = loc("tmp0"(#loc26)) +#loc88 = loc("tmp0"(#loc27)) +#loc89 = loc("tmp0"(#loc28)) +#loc90 = loc("tmp0"(#loc29)) +#loc91 = loc("tmp0"(#loc30)) +#loc92 = loc("tmp0"(#loc31)) +#loc93 = loc("tmp0"(#loc32)) +#loc94 = loc("tmp0"(#loc33)) +#loc95 = loc("tmp0"(#loc34)) +#loc96 = loc("tmp0"(#loc35)) +#loc97 = loc("tmp1"(#loc36)) +#loc98 = loc("tmp1"(#loc37)) +#loc99 = loc("tmp1"(#loc38)) +#loc100 = loc("tmp1"(#loc39)) +#loc101 = loc("tmp1"(#loc40)) +#loc102 = loc("tmp1"(#loc41)) +#loc103 = loc("tmp1"(#loc42)) +#loc104 = loc("tmp1"(#loc43)) +#loc105 = loc("tmp1"(#loc44)) +#loc106 = loc("tmp1"(#loc45)) +#loc107 = loc("tmp1"(#loc46)) +#loc108 = loc("tmp1"(#loc47)) +#loc109 = loc("tmp1"(#loc48)) +#loc110 = loc("tmp2"(#loc49)) +#loc111 = loc("tmp5"(#loc50)) +#loc112 = loc("_tmp4"(#loc51)) +#loc114 = loc("tmp4"(#loc56)) +#loc115 = loc(callsite(#loc67 at #loc68)) +#loc116 = loc(callsite(#loc3 at #loc68)) +#loc117 = loc(callsite(#loc69 at #loc68)) +#loc118 = loc(callsite(#loc5 at #loc68)) +#loc119 = loc(callsite(#loc82 at #loc68)) +#loc120 = loc(callsite(#loc21 at #loc68)) +#loc121 = loc(callsite(#loc22 at #loc68)) +#loc122 = loc(fused[#loc103, #loc104]) +#loc123 = loc(callsite(#loc53 at #loc113)) +#loc125 = loc(callsite(#loc55 at #loc123)) diff 
--git a/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.llir new file mode 100644 index 0000000000000000000000000000000000000000..95bd1fd6298687177e98b144f4dc3b34bcf22922 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.llir @@ -0,0 +1,12657 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 +@.str = private unnamed_addr constant [11 x i8] c"__CUDA_FTZ\00", align 1 + +; Function Attrs: nounwind +define ptx_kernel void @triton_tem_fused_zeros_1(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, ptr addrspace(1) %7, ptr addrspace(1) %8, ptr addrspace(1) %9, ptr addrspace(1) %10, ptr addrspace(1) %11, ptr addrspace(1) %12, ptr addrspace(1) %13, ptr addrspace(1) %14, ptr addrspace(1) %15, ptr addrspace(1) %16, ptr addrspace(1) %17, ptr addrspace(1) readnone captures(none) %18, ptr addrspace(1) readnone captures(none) %19) local_unnamed_addr #0 !dbg !5 { + %21 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !8 + %22 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !9 + %23 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !dbg !10 + %24 = and i32 %22, 1, !dbg !11 + %25 = shl i32 %23, 18, !dbg !12 + %26 = shl nuw nsw i32 %24, 21, !dbg !13 + %27 = add i32 %26, %25, !dbg !14 + %28 = sext i32 %27 to i64, !dbg !15 + %29 = shl i32 %22, 21, !dbg !16 + %30 = add i32 %25, %29, !dbg !17 + %31 = sext i32 %30 to i64, !dbg !18 + %32 = getelementptr bfloat, ptr addrspace(1) %1, 
i64 %28, !dbg !19 + %33 = getelementptr bfloat, ptr addrspace(1) %2, i64 %28, !dbg !20 + %34 = getelementptr bfloat, ptr addrspace(1) %7, i64 %31, !dbg !21 + %35 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !22 + %36 = lshr i32 %35, 5, !dbg !22 + %37 = and i32 %35, 240, !dbg !22 + %38 = lshr exact i32 %37, 4, !dbg !22 + %39 = or disjoint i32 %38, 16, !dbg !22 + %40 = or disjoint i32 %38, 32, !dbg !22 + %41 = or disjoint i32 %38, 48, !dbg !22 + %42 = or disjoint i32 %38, 64, !dbg !22 + %43 = or disjoint i32 %38, 80, !dbg !22 + %44 = or disjoint i32 %38, 96, !dbg !22 + %45 = or disjoint i32 %38, 112, !dbg !22 + %46 = lshr i32 %35, 1, !dbg !22 + %47 = and i32 %46, 112, !dbg !22 + %48 = lshr i32 %35, 2, !dbg !22 + %49 = and i32 %48, 7, !dbg !22 + %50 = or disjoint i32 %47, %49, !dbg !22 + %51 = or disjoint i32 %50, 8, !dbg !22 + %52 = icmp samesign ugt i32 %21, 15, !dbg !23 + br i1 %52, label %53, label %4460, !dbg !24 + +53: ; preds = %20 + %54 = add nsw i32 %21, -16, !dbg !25 + %55 = lshr i32 %54, 4, !dbg !26 + %56 = shl nuw nsw i32 %23, 2, !dbg !27 + %57 = add nuw nsw i32 %55, %56, !dbg !28 + %58 = and i32 %21, 15, !dbg !29 + %59 = shl nuw nsw i32 %24, 4, !dbg !30 + %60 = or disjoint i32 %59, %58, !dbg !31 + %61 = shl nuw nsw i32 %24, 8, !dbg !32 + %62 = shl nuw nsw i32 %58, 4, !dbg !33 + %63 = or disjoint i32 %61, %62, !dbg !34 + %64 = shl i32 %57, 7, !dbg !35 + %65 = shl i32 %22, 23, !dbg !36 + %66 = add i32 %64, %65, !dbg !37 + %67 = sext i32 %66 to i64, !dbg !38 + %68 = shl i32 %57, 18, !dbg !39 + %69 = add i32 %68, %65, !dbg !40 + %70 = sext i32 %69 to i64, !dbg !41 + %71 = shl nuw i32 %22, 16, !dbg !42 + %72 = shl i32 %57, 11, !dbg !42 + %73 = add i32 %72, %71, !dbg !42 + %74 = sext i32 %73 to i64, !dbg !43 + %75 = getelementptr bfloat, ptr addrspace(1) %0, i64 %67, !dbg !44 + %76 = getelementptr bfloat, ptr addrspace(1) %5, i64 %70, !dbg !45 + %77 = getelementptr bfloat, ptr addrspace(1) %6, i64 %67, !dbg !46 + %78 = getelementptr float, ptr 
addrspace(1) %3, i64 %74, !dbg !47 + %79 = getelementptr float, ptr addrspace(1) %4, i64 %74, !dbg !48 + %80 = shl nuw nsw i32 %58, 7, !dbg !49 + %81 = or disjoint i32 %80, %38, !dbg !50 + %82 = or disjoint i32 %39, %80, !dbg !50 + %83 = or disjoint i32 %40, %80, !dbg !50 + %84 = or disjoint i32 %41, %80, !dbg !50 + %85 = or disjoint i32 %42, %80, !dbg !50 + %86 = or disjoint i32 %43, %80, !dbg !50 + %87 = or disjoint i32 %44, %80, !dbg !50 + %88 = or disjoint i32 %45, %80, !dbg !50 + %89 = insertelement <2 x i32> poison, i32 %51, i64 0, !dbg !50 + %90 = insertelement <2 x i32> %89, i32 %50, i64 1, !dbg !50 + %91 = insertelement <2 x i32> poison, i32 %80, i64 0, !dbg !50 + %92 = shufflevector <2 x i32> %90, <2 x i32> poison, <32 x i32> , !dbg !50 + %93 = shufflevector <2 x i32> %91, <2 x i32> poison, <32 x i32> zeroinitializer, !dbg !50 + %94 = or disjoint <32 x i32> %92, %93, !dbg !50 + %95 = shl nuw nsw i32 %81, 12, !dbg !51 + %96 = shl nuw nsw i32 %82, 12, !dbg !51 + %97 = shl nuw nsw i32 %83, 12, !dbg !51 + %98 = shl nuw nsw i32 %84, 12, !dbg !51 + %99 = shl nuw nsw i32 %85, 12, !dbg !51 + %100 = shl nuw nsw i32 %86, 12, !dbg !51 + %101 = shl nuw nsw i32 %87, 12, !dbg !51 + %102 = shl nuw nsw i32 %88, 12, !dbg !51 + %103 = zext nneg i32 %95 to i64, !dbg !54 + %104 = getelementptr bfloat, ptr addrspace(1) %75, i64 %103, !dbg !54 + %105 = zext nneg i32 %96 to i64, !dbg !54 + %106 = getelementptr bfloat, ptr addrspace(1) %75, i64 %105, !dbg !54 + %107 = zext nneg i32 %97 to i64, !dbg !54 + %108 = getelementptr bfloat, ptr addrspace(1) %75, i64 %107, !dbg !54 + %109 = zext nneg i32 %98 to i64, !dbg !54 + %110 = getelementptr bfloat, ptr addrspace(1) %75, i64 %109, !dbg !54 + %111 = zext nneg i32 %99 to i64, !dbg !54 + %112 = getelementptr bfloat, ptr addrspace(1) %75, i64 %111, !dbg !54 + %113 = zext nneg i32 %100 to i64, !dbg !54 + %114 = getelementptr bfloat, ptr addrspace(1) %75, i64 %113, !dbg !54 + %115 = zext nneg i32 %101 to i64, !dbg !54 + %116 = 
getelementptr bfloat, ptr addrspace(1) %75, i64 %115, !dbg !54 + %117 = zext nneg i32 %102 to i64, !dbg !54 + %118 = getelementptr bfloat, ptr addrspace(1) %75, i64 %117, !dbg !54 + %119 = shl nuw nsw i32 %35, 3, !dbg !55 + %120 = and i32 %119, 120, !dbg !55 + %121 = zext nneg i32 %120 to i64, !dbg !56 + %122 = getelementptr bfloat, ptr addrspace(1) %104, i64 %121, !dbg !56 + %123 = getelementptr bfloat, ptr addrspace(1) %106, i64 %121, !dbg !56 + %124 = getelementptr bfloat, ptr addrspace(1) %108, i64 %121, !dbg !56 + %125 = getelementptr bfloat, ptr addrspace(1) %110, i64 %121, !dbg !56 + %126 = getelementptr bfloat, ptr addrspace(1) %112, i64 %121, !dbg !56 + %127 = getelementptr bfloat, ptr addrspace(1) %114, i64 %121, !dbg !56 + %128 = getelementptr bfloat, ptr addrspace(1) %116, i64 %121, !dbg !56 + %129 = getelementptr bfloat, ptr addrspace(1) %118, i64 %121, !dbg !56 + %130 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %122) #3, !dbg !57 + %131 = extractvalue { i32, i32, i32, i32 } %130, 0, !dbg !57 + %132 = extractvalue { i32, i32, i32, i32 } %130, 1, !dbg !57 + %133 = extractvalue { i32, i32, i32, i32 } %130, 2, !dbg !57 + %134 = extractvalue { i32, i32, i32, i32 } %130, 3, !dbg !57 + %135 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %123) #3, !dbg !57 + %136 = extractvalue { i32, i32, i32, i32 } %135, 0, !dbg !57 + %137 = extractvalue { i32, i32, i32, i32 } %135, 1, !dbg !57 + %138 = extractvalue { i32, i32, i32, i32 } %135, 2, !dbg !57 + %139 = extractvalue { i32, i32, i32, i32 } %135, 3, !dbg !57 + %140 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 
$2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %124) #3, !dbg !57 + %141 = extractvalue { i32, i32, i32, i32 } %140, 0, !dbg !57 + %142 = extractvalue { i32, i32, i32, i32 } %140, 1, !dbg !57 + %143 = extractvalue { i32, i32, i32, i32 } %140, 2, !dbg !57 + %144 = extractvalue { i32, i32, i32, i32 } %140, 3, !dbg !57 + %145 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %125) #3, !dbg !57 + %146 = extractvalue { i32, i32, i32, i32 } %145, 0, !dbg !57 + %147 = extractvalue { i32, i32, i32, i32 } %145, 1, !dbg !57 + %148 = extractvalue { i32, i32, i32, i32 } %145, 2, !dbg !57 + %149 = extractvalue { i32, i32, i32, i32 } %145, 3, !dbg !57 + %150 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %126) #3, !dbg !57 + %151 = extractvalue { i32, i32, i32, i32 } %150, 0, !dbg !57 + %152 = extractvalue { i32, i32, i32, i32 } %150, 1, !dbg !57 + %153 = extractvalue { i32, i32, i32, i32 } %150, 2, !dbg !57 + %154 = extractvalue { i32, i32, i32, i32 } %150, 3, !dbg !57 + %155 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %127) #3, !dbg !57 + %156 = extractvalue { i32, i32, i32, i32 } %155, 0, !dbg !57 + %157 = extractvalue { i32, i32, i32, i32 } %155, 1, !dbg !57 + %158 = extractvalue { i32, i32, i32, i32 } %155, 2, !dbg !57 + %159 = extractvalue { i32, i32, i32, i32 } %155, 3, !dbg !57 + %160 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 
0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %128) #3, !dbg !57 + %161 = extractvalue { i32, i32, i32, i32 } %160, 0, !dbg !57 + %162 = extractvalue { i32, i32, i32, i32 } %160, 1, !dbg !57 + %163 = extractvalue { i32, i32, i32, i32 } %160, 2, !dbg !57 + %164 = extractvalue { i32, i32, i32, i32 } %160, 3, !dbg !57 + %165 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %129) #3, !dbg !57 + %166 = extractvalue { i32, i32, i32, i32 } %165, 0, !dbg !57 + %167 = extractvalue { i32, i32, i32, i32 } %165, 1, !dbg !57 + %168 = extractvalue { i32, i32, i32, i32 } %165, 2, !dbg !57 + %169 = extractvalue { i32, i32, i32, i32 } %165, 3, !dbg !57 + %170 = shl nuw nsw i32 %35, 4, !dbg !57 + %171 = and i32 %170, 112, !dbg !57 + %172 = shl nuw nsw i32 %37, 3, !dbg !57 + %173 = and i32 %35, 112, !dbg !57 + %174 = and i32 %35, 8, !dbg !57 + %175 = shl nuw nsw i32 %174, 11, !dbg !57 + %176 = or disjoint i32 %171, %172, !dbg !57 + %177 = xor i32 %176, %173, !dbg !57 + %178 = or disjoint i32 %177, %175, !dbg !57 + %179 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %178, !dbg !57 + %180 = insertelement <4 x i32> poison, i32 %131, i64 0, !dbg !57 + %181 = insertelement <4 x i32> %180, i32 %132, i64 1, !dbg !57 + %182 = insertelement <4 x i32> %181, i32 %133, i64 2, !dbg !57 + %183 = insertelement <4 x i32> %182, i32 %134, i64 3, !dbg !57 + store <4 x i32> %183, ptr addrspace(3) %179, align 16, !dbg !57 + %184 = or disjoint i32 %178, 2048, !dbg !57 + %185 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %184, !dbg !57 + %186 = insertelement <4 x i32> poison, i32 %136, i64 0, !dbg !57 + %187 
= insertelement <4 x i32> %186, i32 %137, i64 1, !dbg !57 + %188 = insertelement <4 x i32> %187, i32 %138, i64 2, !dbg !57 + %189 = insertelement <4 x i32> %188, i32 %139, i64 3, !dbg !57 + store <4 x i32> %189, ptr addrspace(3) %185, align 16, !dbg !57 + %190 = or disjoint i32 %178, 4096, !dbg !57 + %191 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %190, !dbg !57 + %192 = insertelement <4 x i32> poison, i32 %141, i64 0, !dbg !57 + %193 = insertelement <4 x i32> %192, i32 %142, i64 1, !dbg !57 + %194 = insertelement <4 x i32> %193, i32 %143, i64 2, !dbg !57 + %195 = insertelement <4 x i32> %194, i32 %144, i64 3, !dbg !57 + store <4 x i32> %195, ptr addrspace(3) %191, align 16, !dbg !57 + %196 = or disjoint i32 %178, 6144, !dbg !57 + %197 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %196, !dbg !57 + %198 = insertelement <4 x i32> poison, i32 %146, i64 0, !dbg !57 + %199 = insertelement <4 x i32> %198, i32 %147, i64 1, !dbg !57 + %200 = insertelement <4 x i32> %199, i32 %148, i64 2, !dbg !57 + %201 = insertelement <4 x i32> %200, i32 %149, i64 3, !dbg !57 + store <4 x i32> %201, ptr addrspace(3) %197, align 16, !dbg !57 + %202 = or disjoint i32 %178, 8192, !dbg !57 + %203 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %202, !dbg !57 + %204 = insertelement <4 x i32> poison, i32 %151, i64 0, !dbg !57 + %205 = insertelement <4 x i32> %204, i32 %152, i64 1, !dbg !57 + %206 = insertelement <4 x i32> %205, i32 %153, i64 2, !dbg !57 + %207 = insertelement <4 x i32> %206, i32 %154, i64 3, !dbg !57 + store <4 x i32> %207, ptr addrspace(3) %203, align 16, !dbg !57 + %208 = or disjoint i32 %178, 10240, !dbg !57 + %209 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %208, !dbg !57 + %210 = insertelement 
<4 x i32> poison, i32 %156, i64 0, !dbg !57 + %211 = insertelement <4 x i32> %210, i32 %157, i64 1, !dbg !57 + %212 = insertelement <4 x i32> %211, i32 %158, i64 2, !dbg !57 + %213 = insertelement <4 x i32> %212, i32 %159, i64 3, !dbg !57 + store <4 x i32> %213, ptr addrspace(3) %209, align 16, !dbg !57 + %214 = or disjoint i32 %178, 12288, !dbg !57 + %215 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %214, !dbg !57 + %216 = insertelement <4 x i32> poison, i32 %161, i64 0, !dbg !57 + %217 = insertelement <4 x i32> %216, i32 %162, i64 1, !dbg !57 + %218 = insertelement <4 x i32> %217, i32 %163, i64 2, !dbg !57 + %219 = insertelement <4 x i32> %218, i32 %164, i64 3, !dbg !57 + store <4 x i32> %219, ptr addrspace(3) %215, align 16, !dbg !57 + %220 = or disjoint i32 %178, 14336, !dbg !57 + %221 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %220, !dbg !57 + %222 = insertelement <4 x i32> poison, i32 %166, i64 0, !dbg !57 + %223 = insertelement <4 x i32> %222, i32 %167, i64 1, !dbg !57 + %224 = insertelement <4 x i32> %223, i32 %168, i64 2, !dbg !57 + %225 = insertelement <4 x i32> %224, i32 %169, i64 3, !dbg !57 + store <4 x i32> %225, ptr addrspace(3) %221, align 16, !dbg !57 + %226 = shl nuw nsw i32 %81, 7, !dbg !58 + %227 = shl nuw nsw i32 %82, 7, !dbg !58 + %228 = shl nuw nsw i32 %83, 7, !dbg !58 + %229 = shl nuw nsw i32 %84, 7, !dbg !58 + %230 = shl nuw nsw i32 %85, 7, !dbg !58 + %231 = shl nuw nsw i32 %86, 7, !dbg !58 + %232 = shl nuw nsw i32 %87, 7, !dbg !58 + %233 = shl nuw nsw i32 %88, 7, !dbg !58 + %234 = zext nneg i32 %226 to i64, !dbg !60 + %235 = getelementptr bfloat, ptr addrspace(1) %76, i64 %234, !dbg !60 + %236 = zext nneg i32 %227 to i64, !dbg !60 + %237 = getelementptr bfloat, ptr addrspace(1) %76, i64 %236, !dbg !60 + %238 = zext nneg i32 %228 to i64, !dbg !60 + %239 = getelementptr bfloat, ptr addrspace(1) %76, 
i64 %238, !dbg !60 + %240 = zext nneg i32 %229 to i64, !dbg !60 + %241 = getelementptr bfloat, ptr addrspace(1) %76, i64 %240, !dbg !60 + %242 = zext nneg i32 %230 to i64, !dbg !60 + %243 = getelementptr bfloat, ptr addrspace(1) %76, i64 %242, !dbg !60 + %244 = zext nneg i32 %231 to i64, !dbg !60 + %245 = getelementptr bfloat, ptr addrspace(1) %76, i64 %244, !dbg !60 + %246 = zext nneg i32 %232 to i64, !dbg !60 + %247 = getelementptr bfloat, ptr addrspace(1) %76, i64 %246, !dbg !60 + %248 = zext nneg i32 %233 to i64, !dbg !60 + %249 = getelementptr bfloat, ptr addrspace(1) %76, i64 %248, !dbg !60 + %250 = getelementptr bfloat, ptr addrspace(1) %235, i64 %121, !dbg !61 + %251 = getelementptr bfloat, ptr addrspace(1) %237, i64 %121, !dbg !61 + %252 = getelementptr bfloat, ptr addrspace(1) %239, i64 %121, !dbg !61 + %253 = getelementptr bfloat, ptr addrspace(1) %241, i64 %121, !dbg !61 + %254 = getelementptr bfloat, ptr addrspace(1) %243, i64 %121, !dbg !61 + %255 = getelementptr bfloat, ptr addrspace(1) %245, i64 %121, !dbg !61 + %256 = getelementptr bfloat, ptr addrspace(1) %247, i64 %121, !dbg !61 + %257 = getelementptr bfloat, ptr addrspace(1) %249, i64 %121, !dbg !61 + %258 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %250) #3, !dbg !62 + %259 = extractvalue { i32, i32, i32, i32 } %258, 0, !dbg !62 + %260 = extractvalue { i32, i32, i32, i32 } %258, 1, !dbg !62 + %261 = extractvalue { i32, i32, i32, i32 } %258, 2, !dbg !62 + %262 = extractvalue { i32, i32, i32, i32 } %258, 3, !dbg !62 + %263 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %251) #3, !dbg !62 + %264 = extractvalue { i32, i32, i32, i32 } %263, 0, !dbg 
!62 + %265 = extractvalue { i32, i32, i32, i32 } %263, 1, !dbg !62 + %266 = extractvalue { i32, i32, i32, i32 } %263, 2, !dbg !62 + %267 = extractvalue { i32, i32, i32, i32 } %263, 3, !dbg !62 + %268 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %252) #3, !dbg !62 + %269 = extractvalue { i32, i32, i32, i32 } %268, 0, !dbg !62 + %270 = extractvalue { i32, i32, i32, i32 } %268, 1, !dbg !62 + %271 = extractvalue { i32, i32, i32, i32 } %268, 2, !dbg !62 + %272 = extractvalue { i32, i32, i32, i32 } %268, 3, !dbg !62 + %273 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %253) #3, !dbg !62 + %274 = extractvalue { i32, i32, i32, i32 } %273, 0, !dbg !62 + %275 = extractvalue { i32, i32, i32, i32 } %273, 1, !dbg !62 + %276 = extractvalue { i32, i32, i32, i32 } %273, 2, !dbg !62 + %277 = extractvalue { i32, i32, i32, i32 } %273, 3, !dbg !62 + %278 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %254) #3, !dbg !62 + %279 = extractvalue { i32, i32, i32, i32 } %278, 0, !dbg !62 + %280 = extractvalue { i32, i32, i32, i32 } %278, 1, !dbg !62 + %281 = extractvalue { i32, i32, i32, i32 } %278, 2, !dbg !62 + %282 = extractvalue { i32, i32, i32, i32 } %278, 3, !dbg !62 + %283 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %255) #3, !dbg !62 + %284 = extractvalue { i32, i32, i32, i32 } %283, 0, 
!dbg !62 + %285 = extractvalue { i32, i32, i32, i32 } %283, 1, !dbg !62 + %286 = extractvalue { i32, i32, i32, i32 } %283, 2, !dbg !62 + %287 = extractvalue { i32, i32, i32, i32 } %283, 3, !dbg !62 + %288 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %256) #3, !dbg !62 + %289 = extractvalue { i32, i32, i32, i32 } %288, 0, !dbg !62 + %290 = extractvalue { i32, i32, i32, i32 } %288, 1, !dbg !62 + %291 = extractvalue { i32, i32, i32, i32 } %288, 2, !dbg !62 + %292 = extractvalue { i32, i32, i32, i32 } %288, 3, !dbg !62 + %293 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %257) #3, !dbg !62 + %294 = extractvalue { i32, i32, i32, i32 } %293, 0, !dbg !62 + %295 = extractvalue { i32, i32, i32, i32 } %293, 1, !dbg !62 + %296 = extractvalue { i32, i32, i32, i32 } %293, 2, !dbg !62 + %297 = extractvalue { i32, i32, i32, i32 } %293, 3, !dbg !62 + %298 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %178, !dbg !62 + %299 = insertelement <4 x i32> poison, i32 %259, i64 0, !dbg !62 + %300 = insertelement <4 x i32> %299, i32 %260, i64 1, !dbg !62 + %301 = insertelement <4 x i32> %300, i32 %261, i64 2, !dbg !62 + %302 = insertelement <4 x i32> %301, i32 %262, i64 3, !dbg !62 + store <4 x i32> %302, ptr addrspace(3) %298, align 16, !dbg !62 + %303 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %184, !dbg !62 + %304 = insertelement <4 x i32> poison, i32 %264, i64 0, !dbg !62 + %305 = insertelement <4 x i32> %304, i32 %265, i64 1, !dbg !62 + %306 = insertelement <4 x i32> %305, i32 %266, i64 2, !dbg 
!62 + %307 = insertelement <4 x i32> %306, i32 %267, i64 3, !dbg !62 + store <4 x i32> %307, ptr addrspace(3) %303, align 16, !dbg !62 + %308 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %190, !dbg !62 + %309 = insertelement <4 x i32> poison, i32 %269, i64 0, !dbg !62 + %310 = insertelement <4 x i32> %309, i32 %270, i64 1, !dbg !62 + %311 = insertelement <4 x i32> %310, i32 %271, i64 2, !dbg !62 + %312 = insertelement <4 x i32> %311, i32 %272, i64 3, !dbg !62 + store <4 x i32> %312, ptr addrspace(3) %308, align 16, !dbg !62 + %313 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %196, !dbg !62 + %314 = insertelement <4 x i32> poison, i32 %274, i64 0, !dbg !62 + %315 = insertelement <4 x i32> %314, i32 %275, i64 1, !dbg !62 + %316 = insertelement <4 x i32> %315, i32 %276, i64 2, !dbg !62 + %317 = insertelement <4 x i32> %316, i32 %277, i64 3, !dbg !62 + store <4 x i32> %317, ptr addrspace(3) %313, align 16, !dbg !62 + %318 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %202, !dbg !62 + %319 = insertelement <4 x i32> poison, i32 %279, i64 0, !dbg !62 + %320 = insertelement <4 x i32> %319, i32 %280, i64 1, !dbg !62 + %321 = insertelement <4 x i32> %320, i32 %281, i64 2, !dbg !62 + %322 = insertelement <4 x i32> %321, i32 %282, i64 3, !dbg !62 + store <4 x i32> %322, ptr addrspace(3) %318, align 16, !dbg !62 + %323 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %208, !dbg !62 + %324 = insertelement <4 x i32> poison, i32 %284, i64 0, !dbg !62 + %325 = insertelement <4 x i32> %324, i32 %285, i64 1, !dbg !62 + %326 = insertelement <4 x i32> %325, i32 %286, i64 2, !dbg !62 + %327 = insertelement <4 x i32> %326, i32 %287, i64 3, !dbg !62 + store <4 x i32> %327, ptr addrspace(3) %323, align 16, !dbg 
!62 + %328 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %214, !dbg !62 + %329 = insertelement <4 x i32> poison, i32 %289, i64 0, !dbg !62 + %330 = insertelement <4 x i32> %329, i32 %290, i64 1, !dbg !62 + %331 = insertelement <4 x i32> %330, i32 %291, i64 2, !dbg !62 + %332 = insertelement <4 x i32> %331, i32 %292, i64 3, !dbg !62 + store <4 x i32> %332, ptr addrspace(3) %328, align 16, !dbg !62 + %333 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %220, !dbg !62 + %334 = insertelement <4 x i32> poison, i32 %294, i64 0, !dbg !62 + %335 = insertelement <4 x i32> %334, i32 %295, i64 1, !dbg !62 + %336 = insertelement <4 x i32> %335, i32 %296, i64 2, !dbg !62 + %337 = insertelement <4 x i32> %336, i32 %297, i64 3, !dbg !62 + store <4 x i32> %337, ptr addrspace(3) %333, align 16, !dbg !62 + %338 = extractelement <32 x i32> %94, i64 2, !dbg !63 + %339 = zext nneg i32 %338 to i64, !dbg !63 + %340 = getelementptr float, ptr addrspace(1) %79, i64 %339, !dbg !63 + %341 = extractelement <32 x i32> %94, i64 0, !dbg !63 + %342 = zext nneg i32 %341 to i64, !dbg !63 + %343 = getelementptr float, ptr addrspace(1) %79, i64 %342, !dbg !63 + %344 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %340) #3, !dbg !64 + %345 = bitcast i32 %344 to float, !dbg !64 + %346 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %343) #3, !dbg !64 + %347 = bitcast i32 %346 to float, !dbg !64 + %348 = getelementptr float, ptr addrspace(1) %78, i64 %339, !dbg !65 + %349 = getelementptr float, ptr addrspace(1) %78, i64 %342, !dbg !65 + %350 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %348) #3, !dbg !66 + %351 = bitcast i32 %350 to float, !dbg !66 + %352 = tail call 
i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %349) #3, !dbg !66 + %353 = bitcast i32 %352 to float, !dbg !66 + %354 = fcmp oeq float %351, 0xFFF0000000000000, !dbg !67 + %355 = fcmp oeq float %353, 0xFFF0000000000000, !dbg !67 + %356 = select i1 %354, float 0.000000e+00, float %351, !dbg !68 + %357 = select i1 %355, float 0.000000e+00, float %353, !dbg !68 + %358 = zext nneg i32 %63 to i64, !dbg !69 + %359 = getelementptr i32, ptr addrspace(1) %9, i64 %358, !dbg !69 + %360 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %359) #3, !dbg !70 + %361 = shl i32 %360, 7, !dbg !71 + %362 = zext nneg i32 %60 to i64, !dbg !72 + %363 = getelementptr i32, ptr addrspace(1) %8, i64 %362, !dbg !72 + %364 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %363) #3, !dbg !73 + %365 = and i32 %35, 3, !dbg !74 + %366 = or disjoint i32 %361, %38, !dbg !75 + %367 = or disjoint i32 %361, %39, !dbg !75 + %368 = or disjoint i32 %361, %40, !dbg !75 + %369 = or disjoint i32 %361, %41, !dbg !75 + %370 = shl i32 %366, 7, !dbg !76 + %371 = shl i32 %367, 7, !dbg !76 + %372 = shl i32 %368, 7, !dbg !76 + %373 = shl i32 %369, 7, !dbg !76 + %374 = sext i32 %370 to i64, !dbg !78 + %375 = getelementptr bfloat, ptr addrspace(1) %32, i64 %374, !dbg !78 + %376 = sext i32 %371 to i64, !dbg !78 + %377 = getelementptr bfloat, ptr addrspace(1) %32, i64 %376, !dbg !78 + %378 = sext i32 %372 to i64, !dbg !78 + %379 = getelementptr bfloat, ptr addrspace(1) %32, i64 %378, !dbg !78 + %380 = sext i32 %373 to i64, !dbg !78 + %381 = getelementptr bfloat, ptr addrspace(1) %32, i64 %380, !dbg !78 + %382 = getelementptr bfloat, ptr addrspace(1) %375, i64 %121, !dbg !79 + %383 = getelementptr bfloat, ptr addrspace(1) %377, i64 %121, !dbg !79 + %384 = getelementptr bfloat, ptr addrspace(1) %379, i64 %121, !dbg !79 + %385 = getelementptr 
bfloat, ptr addrspace(1) %381, i64 %121, !dbg !79 + %386 = getelementptr bfloat, ptr addrspace(1) %33, i64 %374, !dbg !80 + %387 = getelementptr bfloat, ptr addrspace(1) %33, i64 %376, !dbg !80 + %388 = getelementptr bfloat, ptr addrspace(1) %33, i64 %378, !dbg !80 + %389 = getelementptr bfloat, ptr addrspace(1) %33, i64 %380, !dbg !80 + %390 = getelementptr bfloat, ptr addrspace(1) %386, i64 %121, !dbg !81 + %391 = getelementptr bfloat, ptr addrspace(1) %387, i64 %121, !dbg !81 + %392 = getelementptr bfloat, ptr addrspace(1) %388, i64 %121, !dbg !81 + %393 = getelementptr bfloat, ptr addrspace(1) %389, i64 %121, !dbg !81 + %394 = shl i32 %364, 1, !dbg !82 + %395 = zext nneg i32 %22 to i64, !dbg !83 + %396 = getelementptr i64, ptr addrspace(1) %16, i64 %395, !dbg !83 + %397 = icmp sgt i32 %394, 0, !dbg !84 + %398 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %396, i1 %397) #3, !dbg !85 + %399 = icmp sgt i64 %398, %339, !dbg !86 + %400 = icmp sgt i64 %398, %342, !dbg !86 + %401 = shl nuw nsw i32 %174, 10, !dbg !87 + %402 = or disjoint i32 %177, %401, !dbg !87 + %403 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %402, !dbg !87 + %404 = select i1 %397, i32 16, i32 0, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %403, ptr addrspace(1) %382, i32 %404) #3, !dbg !87 + %405 = or disjoint i32 %402, 2048, !dbg !87 + %406 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %406, ptr addrspace(1) %383, i32 %404) #3, !dbg !87 + %407 = or disjoint i32 %402, 4096, !dbg !87 + %408 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 
0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %408, ptr addrspace(1) %384, i32 %404) #3, !dbg !87 + %409 = or disjoint i32 %402, 6144, !dbg !87 + %410 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %409, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %410, ptr addrspace(1) %385, i32 %404) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + %411 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %402, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %411, ptr addrspace(1) %390, i32 %404) #3, !dbg !87 + %412 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %412, ptr addrspace(1) %391, i32 %404) #3, !dbg !87 + %413 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %413, ptr addrspace(1) %392, i32 %404) #3, !dbg !87 + %414 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %409, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %414, ptr addrspace(1) %393, i32 %404) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + %415 = icmp sgt i32 %394, 1, !dbg !84 + %416 = getelementptr i8, ptr addrspace(1) %382, i64 16384, !dbg !88 + %417 = getelementptr i8, ptr addrspace(1) %383, i64 16384, !dbg !88 + %418 = 
getelementptr i8, ptr addrspace(1) %384, i64 16384, !dbg !88 + %419 = getelementptr i8, ptr addrspace(1) %385, i64 16384, !dbg !88 + %420 = getelementptr i8, ptr addrspace(1) %390, i64 16384, !dbg !89 + %421 = getelementptr i8, ptr addrspace(1) %391, i64 16384, !dbg !89 + %422 = getelementptr i8, ptr addrspace(1) %392, i64 16384, !dbg !89 + %423 = getelementptr i8, ptr addrspace(1) %393, i64 16384, !dbg !89 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !87 + %424 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %402, !dbg !87 + %425 = select i1 %415, i32 16, i32 0, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %424, ptr addrspace(1) %416, i32 %425) #3, !dbg !87 + %426 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %426, ptr addrspace(1) %417, i32 %425) #3, !dbg !87 + %427 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %427, ptr addrspace(1) %418, i32 %425) #3, !dbg !87 + %428 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %409, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %428, ptr addrspace(1) %419, i32 %425) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + %429 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %402, !dbg !87 + 
tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %429, ptr addrspace(1) %420, i32 %425) #3, !dbg !87 + %430 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %430, ptr addrspace(1) %421, i32 %425) #3, !dbg !87 + %431 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %431, ptr addrspace(1) %422, i32 %425) #3, !dbg !87 + %432 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %409, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %432, ptr addrspace(1) %423, i32 %425) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !90 + br i1 %397, label %.lr.ph1846, label %._crit_edge1847, !dbg !84 + +.lr.ph1846: ; preds = %53 + %433 = tail call i32 @llvm.umin.i32(i32 %394, i32 32), !dbg !91 + %434 = shl nuw nsw i32 %365, 1, !dbg !74 + %435 = or disjoint i32 %361, %434, !dbg !74 + %436 = insertelement <8 x i32> poison, i32 %435, i64 0, !dbg !75 + %437 = add nsw i32 %433, -2 + %438 = add nsw i32 %433, -1 + %439 = shufflevector <8 x i32> %436, <8 x i32> poison, <16 x i32> + %440 = insertelement <16 x i32> %439, i32 %435, i64 8 + %441 = insertelement <16 x i32> %440, i32 %435, i64 9 + %442 = insertelement <16 x i32> %441, i32 %435, i64 10 + %443 = insertelement <16 x i32> %442, i32 %435, i64 11 + %444 = insertelement <16 x i32> %443, i32 %435, i64 12 + 
%445 = insertelement <16 x i32> %444, i32 %435, i64 13 + %446 = insertelement <16 x i32> %445, i32 %435, i64 14 + %447 = or disjoint <16 x i32> %446, + %448 = insertelement <16 x i32> %447, i32 %435, i64 15 + %449 = insertelement <8 x i64> poison, i64 %398, i64 0, !dbg !92 + %450 = shufflevector <8 x i64> %449, <8 x i64> poison, <8 x i32> zeroinitializer, !dbg !92 + br label %451, !dbg !84 + +451: ; preds = %.lr.ph1846, %__nv_exp2f.exit1507 + %452 = phi i32 [ 64, %.lr.ph1846 ], [ %2289, %__nv_exp2f.exit1507 ] + %453 = phi i32 [ -1, %.lr.ph1846 ], [ %526, %__nv_exp2f.exit1507 ] + %454 = phi i32 [ 1, %.lr.ph1846 ], [ %2302, %__nv_exp2f.exit1507 ] + %.pn9181828 = phi ptr addrspace(1) [ %423, %.lr.ph1846 ], [ %2299, %__nv_exp2f.exit1507 ] + %.pn9341827 = phi ptr addrspace(1) [ %422, %.lr.ph1846 ], [ %2298, %__nv_exp2f.exit1507 ] + %.pn9501826 = phi ptr addrspace(1) [ %421, %.lr.ph1846 ], [ %2297, %__nv_exp2f.exit1507 ] + %.pn9661825 = phi ptr addrspace(1) [ %420, %.lr.ph1846 ], [ %2296, %__nv_exp2f.exit1507 ] + %.pn8541824 = phi ptr addrspace(1) [ %419, %.lr.ph1846 ], [ %2295, %__nv_exp2f.exit1507 ] + %.pn8701823 = phi ptr addrspace(1) [ %418, %.lr.ph1846 ], [ %2294, %__nv_exp2f.exit1507 ] + %.pn8861822 = phi ptr addrspace(1) [ %417, %.lr.ph1846 ], [ %2293, %__nv_exp2f.exit1507 ] + %.pn9021821 = phi ptr addrspace(1) [ %416, %.lr.ph1846 ], [ %2292, %__nv_exp2f.exit1507 ] + %455 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2203, %__nv_exp2f.exit1507 ] + %456 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2204, %__nv_exp2f.exit1507 ] + %457 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2205, %__nv_exp2f.exit1507 ] + %458 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2206, %__nv_exp2f.exit1507 ] + %459 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2207, %__nv_exp2f.exit1507 ] + %460 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2208, %__nv_exp2f.exit1507 ] + %461 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2209, %__nv_exp2f.exit1507 ] + %462 = phi float [ 0.000000e+00, 
%.lr.ph1846 ], [ %2210, %__nv_exp2f.exit1507 ] + %463 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2211, %__nv_exp2f.exit1507 ] + %464 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2212, %__nv_exp2f.exit1507 ] + %465 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2213, %__nv_exp2f.exit1507 ] + %466 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2214, %__nv_exp2f.exit1507 ] + %467 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2215, %__nv_exp2f.exit1507 ] + %468 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2216, %__nv_exp2f.exit1507 ] + %469 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2217, %__nv_exp2f.exit1507 ] + %470 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2218, %__nv_exp2f.exit1507 ] + %471 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2219, %__nv_exp2f.exit1507 ] + %472 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2220, %__nv_exp2f.exit1507 ] + %473 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2221, %__nv_exp2f.exit1507 ] + %474 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2222, %__nv_exp2f.exit1507 ] + %475 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2223, %__nv_exp2f.exit1507 ] + %476 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2224, %__nv_exp2f.exit1507 ] + %477 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2225, %__nv_exp2f.exit1507 ] + %478 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2226, %__nv_exp2f.exit1507 ] + %479 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2227, %__nv_exp2f.exit1507 ] + %480 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2228, %__nv_exp2f.exit1507 ] + %481 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2229, %__nv_exp2f.exit1507 ] + %482 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2230, %__nv_exp2f.exit1507 ] + %483 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2231, %__nv_exp2f.exit1507 ] + %484 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2232, %__nv_exp2f.exit1507 ] + %485 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2233, %__nv_exp2f.exit1507 ] + %486 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2234, 
%__nv_exp2f.exit1507 ] + %487 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2235, %__nv_exp2f.exit1507 ] + %488 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2236, %__nv_exp2f.exit1507 ] + %489 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2237, %__nv_exp2f.exit1507 ] + %490 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2238, %__nv_exp2f.exit1507 ] + %491 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2239, %__nv_exp2f.exit1507 ] + %492 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2240, %__nv_exp2f.exit1507 ] + %493 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2241, %__nv_exp2f.exit1507 ] + %494 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2242, %__nv_exp2f.exit1507 ] + %495 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2243, %__nv_exp2f.exit1507 ] + %496 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2244, %__nv_exp2f.exit1507 ] + %497 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2245, %__nv_exp2f.exit1507 ] + %498 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2246, %__nv_exp2f.exit1507 ] + %499 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2247, %__nv_exp2f.exit1507 ] + %500 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2248, %__nv_exp2f.exit1507 ] + %501 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2249, %__nv_exp2f.exit1507 ] + %502 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2250, %__nv_exp2f.exit1507 ] + %503 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2251, %__nv_exp2f.exit1507 ] + %504 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2252, %__nv_exp2f.exit1507 ] + %505 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2253, %__nv_exp2f.exit1507 ] + %506 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2254, %__nv_exp2f.exit1507 ] + %507 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2255, %__nv_exp2f.exit1507 ] + %508 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2256, %__nv_exp2f.exit1507 ] + %509 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2257, %__nv_exp2f.exit1507 ] + %510 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2258, %__nv_exp2f.exit1507 ] + %511 = 
phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2259, %__nv_exp2f.exit1507 ] + %512 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2260, %__nv_exp2f.exit1507 ] + %513 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2261, %__nv_exp2f.exit1507 ] + %514 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2262, %__nv_exp2f.exit1507 ] + %515 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2263, %__nv_exp2f.exit1507 ] + %516 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2264, %__nv_exp2f.exit1507 ] + %517 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2265, %__nv_exp2f.exit1507 ] + %518 = phi float [ 0.000000e+00, %.lr.ph1846 ], [ %2266, %__nv_exp2f.exit1507 ] + %519 = phi i32 [ 0, %.lr.ph1846 ], [ %2270, %__nv_exp2f.exit1507 ] + %520 = phi <16 x i32> [ %448, %.lr.ph1846 ], [ %2269, %__nv_exp2f.exit1507 ] + %521 = shufflevector <16 x i32> %520, <16 x i32> poison, <32 x i32> + %522 = icmp slt i32 %519, %437, !dbg !84 + %523 = icmp slt i32 %519, %438, !dbg !84 + %524 = add i32 %453, 1, !dbg !84 + %525 = icmp sgt i32 %524, 2, !dbg !84 + %526 = select i1 %525, i32 0, i32 %524, !dbg !84 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !87 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !87 + %527 = shl i32 %526, 13, !dbg !87 + %528 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %527, !dbg !87 + %529 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %36, i32 0, i32 31), !dbg !90 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !90 + %530 = shl i32 %529, 11, !dbg !90 + %531 = and i32 %530, 8192, !dbg !90 + %532 = add i32 %531, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %533 = lshr exact i32 %532, 4, !dbg !90 + %534 = and i32 %533, 16383, !dbg !90 + %535 = zext nneg i32 %534 to i64, !dbg !90 + %536 = or disjoint i64 %535, 4611686293372403712, !dbg !90 + %537 = ptrtoint ptr addrspace(3) %528 to i32, !dbg !90 + %538 = lshr exact i32 %537, 4, !dbg !90 + %539 = and i32 
%538, 16383, !dbg !90 + %540 = zext nneg i32 %539 to i64, !dbg !90 + %541 = or disjoint i64 %540, 4611686293338849280, !dbg !90 + %542 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %536, i64 %541) #3, !dbg !90 + %543 = or disjoint i32 %531, 32, !dbg !90 + %544 = add i32 %543, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %545 = lshr exact i32 %544, 4, !dbg !90 + %546 = and i32 %545, 16383, !dbg !90 + %547 = zext nneg i32 %546 to i64, !dbg !90 + %548 = or disjoint i64 %547, 4611686293372403712, !dbg !90 + %549 = add i32 %537, 32, !dbg !90 + %550 = lshr exact i32 %549, 4, !dbg !90 + %551 = and i32 %550, 16383, !dbg !90 + %552 = zext nneg i32 %551 to i64, !dbg !90 + %553 = or disjoint i64 %552, 4611686293338849280, !dbg !90 + %554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 0, !dbg !90 + %555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 1, !dbg !90 + %556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 2, !dbg !90 + %557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 3, !dbg !90 + %558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 4, !dbg !90 + %559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 5, !dbg !90 + %560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 6, !dbg !90 + %561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 7, !dbg !90 + %562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 8, !dbg !90 + %563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 9, !dbg !90 + %564 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 10, !dbg !90 + %565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 11, !dbg !90 + %566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 12, !dbg !90 + %567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 13, !dbg !90 + %568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 14, !dbg !90 + %569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 15, !dbg !90 + %570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 16, !dbg !90 + %571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 17, !dbg !90 + %572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 18, !dbg !90 + %573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 19, !dbg !90 + %574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 20, !dbg !90 + %575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 21, !dbg !90 + %576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 22, !dbg !90 + %577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 23, !dbg !90 + %578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%542, 24, !dbg !90 + %579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 25, !dbg !90 + %580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 26, !dbg !90 + %581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 27, !dbg !90 + %582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 28, !dbg !90 + %583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 29, !dbg !90 + %584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 30, !dbg !90 + %585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %542, 31, !dbg !90 + %586 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %554, float %555, float %556, float %557, float %558, float %559, float %560, float %561, float %562, float %563, float %564, float %565, float %566, float %567, float %568, float %569, float %570, float %571, float %572, float %573, float %574, float %575, float %576, float %577, float %578, float %579, float %580, float %581, float %582, float %583, float %584, float %585, i64 %548, i64 %553, i1 true) #3, !dbg !90 + %587 = or disjoint i32 %531, 64, !dbg !90 + %588 = add i32 %587, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %589 = lshr exact i32 %588, 4, !dbg !90 + %590 = and i32 %589, 16383, !dbg !90 + %591 = zext nneg i32 %590 to i64, !dbg !90 + %592 = or disjoint i64 %591, 4611686293372403712, !dbg !90 + %593 = add i32 %537, 64, !dbg !90 + %594 = lshr exact i32 %593, 4, !dbg !90 + %595 = and i32 %594, 16383, !dbg !90 + %596 = zext nneg i32 %595 to i64, !dbg !90 + %597 = or disjoint i64 %596, 4611686293338849280, !dbg !90 + %598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 0, !dbg !90 + %599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %586, 1, !dbg !90 + %600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 2, !dbg !90 + %601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 3, !dbg !90 + %602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 4, !dbg !90 + %603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 5, !dbg !90 + %604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 6, !dbg !90 + %605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 7, !dbg !90 + %606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 8, !dbg !90 + %607 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 9, !dbg !90 + %608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 10, !dbg !90 + %609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 11, !dbg !90 + %610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 12, !dbg !90 + %611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 13, !dbg !90 + %612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 14, !dbg !90 + %613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 15, !dbg !90 + %614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 16, !dbg !90 + %615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 17, !dbg !90 + %616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 18, !dbg !90 + %617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 19, !dbg !90 + %618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 20, !dbg !90 + %619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 21, !dbg !90 + %620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 22, !dbg !90 + %621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 23, !dbg !90 + 
%622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 24, !dbg !90 + %623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 25, !dbg !90 + %624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 26, !dbg !90 + %625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 27, !dbg !90 + %626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 28, !dbg !90 + %627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 29, !dbg !90 + %628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 30, !dbg !90 + %629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %586, 31, !dbg !90 + %630 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %598, float %599, float %600, float %601, float %602, float %603, float %604, float %605, float %606, float %607, float %608, float %609, float %610, float %611, float %612, float %613, float %614, float %615, float %616, float %617, float %618, float %619, float %620, float %621, float %622, float %623, float %624, float %625, float %626, float %627, float %628, float %629, i64 %592, i64 %597, i1 true) #3, !dbg !90 + %631 = or disjoint i32 %531, 96, !dbg !90 + %632 = add i32 %631, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %633 = lshr exact i32 %632, 4, !dbg !90 + %634 = and i32 %633, 16383, !dbg !90 + %635 = zext nneg i32 %634 to i64, !dbg !90 + %636 = or disjoint i64 %635, 4611686293372403712, !dbg !90 + %637 = add i32 %537, 96, !dbg !90 + %638 = lshr exact i32 %637, 4, !dbg !90 + %639 = and i32 %638, 16383, !dbg !90 + %640 = zext nneg i32 %639 to i64, !dbg !90 + %641 = or disjoint i64 %640, 4611686293338849280, !dbg !90 + %642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %630, 0, !dbg !90 + %643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 1, !dbg !90 + %644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 2, !dbg !90 + %645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 3, !dbg !90 + %646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 4, !dbg !90 + %647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 5, !dbg !90 + %648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 6, !dbg !90 + %649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 7, !dbg !90 + %650 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 8, !dbg !90 + %651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 9, !dbg !90 + %652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 10, !dbg !90 + %653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 11, !dbg !90 + %654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 12, !dbg !90 + %655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 13, !dbg !90 + %656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 14, !dbg !90 + %657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %630, 15, !dbg !90 + %658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 16, !dbg !90 + %659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 17, !dbg !90 + %660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 18, !dbg !90 + %661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 19, !dbg !90 + %662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 20, !dbg !90 + %663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 21, !dbg !90 + %664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 22, !dbg !90 + %665 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 23, !dbg !90 + %666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 24, !dbg !90 + %667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 25, !dbg !90 + %668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 26, !dbg !90 + %669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 27, !dbg !90 + %670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 28, !dbg !90 + %671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 29, !dbg !90 + %672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 30, !dbg !90 + %673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %630, 31, !dbg !90 + %674 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %642, float %643, float %644, float %645, float %646, float %647, float %648, float %649, float %650, float %651, float %652, float %653, float %654, float %655, float %656, float %657, float %658, float %659, float %660, float %661, float %662, float %663, float %664, float %665, float %666, float %667, float %668, float %669, float %670, float %671, float %672, float %673, i64 %636, i64 %641, i1 true) #3, !dbg !90 + %675 = or disjoint i32 %531, 16384, !dbg !90 + %676 = add i32 %675, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %677 = lshr exact i32 %676, 4, !dbg !90 + %678 = and i32 %677, 16383, !dbg !90 + %679 = zext nneg i32 %678 to i64, !dbg !90 + %680 = or disjoint i64 %679, 4611686293372403712, !dbg !90 + %681 = add i32 %537, 8192, !dbg !90 + %682 = lshr exact i32 %681, 4, !dbg !90 + %683 = and i32 %682, 16383, !dbg !90 + %684 = zext nneg i32 %683 to i64, !dbg 
!90 + %685 = or disjoint i64 %684, 4611686293338849280, !dbg !90 + %686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 0, !dbg !90 + %687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 1, !dbg !90 + %688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 2, !dbg !90 + %689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 3, !dbg !90 + %690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 4, !dbg !90 + %691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 5, !dbg !90 + %692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 6, !dbg !90 + %693 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 7, !dbg !90 + %694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 8, !dbg !90 + %695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 9, !dbg !90 + %696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 10, !dbg !90 + %697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 11, !dbg !90 + %698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 12, !dbg !90 + %699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 13, !dbg !90 + %700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %674, 14, !dbg !90 + %701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 15, !dbg !90 + %702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 16, !dbg !90 + %703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 17, !dbg !90 + %704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 18, !dbg !90 + %705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 19, !dbg !90 + %706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 20, !dbg !90 + %707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 21, !dbg !90 + %708 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 22, !dbg !90 + %709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 23, !dbg !90 + %710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 24, !dbg !90 + %711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 25, !dbg !90 + %712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 26, !dbg !90 + %713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 27, !dbg !90 + %714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 28, !dbg !90 + %715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %674, 29, !dbg !90 + %716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 30, !dbg !90 + %717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %674, 31, !dbg !90 + %718 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %686, float %687, float %688, float %689, float %690, float %691, float %692, float %693, float %694, float %695, float %696, float %697, float %698, float %699, float %700, float %701, float %702, float %703, float %704, float %705, float %706, float %707, float %708, float %709, float %710, float %711, float %712, float %713, float %714, float %715, float %716, float %717, i64 %680, i64 %685, i1 true) #3, !dbg !90 + %719 = or disjoint i32 %531, 16416, !dbg !90 + %720 = add i32 %719, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %721 = lshr exact i32 %720, 4, !dbg !90 + %722 = and i32 %721, 16383, !dbg !90 + %723 = zext nneg 
i32 %722 to i64, !dbg !90 + %724 = or disjoint i64 %723, 4611686293372403712, !dbg !90 + %725 = add i32 %537, 8224, !dbg !90 + %726 = lshr exact i32 %725, 4, !dbg !90 + %727 = and i32 %726, 16383, !dbg !90 + %728 = zext nneg i32 %727 to i64, !dbg !90 + %729 = or disjoint i64 %728, 4611686293338849280, !dbg !90 + %730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 0, !dbg !90 + %731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 1, !dbg !90 + %732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 2, !dbg !90 + %733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 3, !dbg !90 + %734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 4, !dbg !90 + %735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 5, !dbg !90 + %736 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 6, !dbg !90 + %737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 7, !dbg !90 + %738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 8, !dbg !90 + %739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 9, !dbg !90 + %740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 10, !dbg !90 + %741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 11, !dbg !90 + %742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 12, !dbg !90 + %743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %718, 13, !dbg !90 + %744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 14, !dbg !90 + %745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 15, !dbg !90 + %746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 16, !dbg !90 + %747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 17, !dbg !90 + %748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 18, !dbg !90 + %749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 19, !dbg !90 + %750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 20, !dbg !90 + %751 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 21, !dbg !90 + %752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 22, !dbg !90 + %753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 23, !dbg !90 + %754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 24, !dbg !90 + %755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 25, !dbg !90 + %756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 26, !dbg !90 + %757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 27, !dbg !90 + %758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %718, 28, !dbg !90 + %759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 29, !dbg !90 + %760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 30, !dbg !90 + %761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %718, 31, !dbg !90 + %762 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %730, float %731, float %732, float %733, float %734, float %735, float %736, float %737, float %738, float %739, float %740, float %741, float %742, float %743, float %744, float %745, float %746, float %747, float %748, float %749, float %750, float %751, float %752, float %753, float %754, float %755, float %756, float %757, float %758, float %759, float %760, float %761, i64 %724, i64 %729, i1 true) #3, !dbg !90 + %763 = or disjoint i32 
%531, 16448, !dbg !90 + %764 = add i32 %763, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %765 = lshr exact i32 %764, 4, !dbg !90 + %766 = and i32 %765, 16383, !dbg !90 + %767 = zext nneg i32 %766 to i64, !dbg !90 + %768 = or disjoint i64 %767, 4611686293372403712, !dbg !90 + %769 = add i32 %537, 8256, !dbg !90 + %770 = lshr exact i32 %769, 4, !dbg !90 + %771 = and i32 %770, 16383, !dbg !90 + %772 = zext nneg i32 %771 to i64, !dbg !90 + %773 = or disjoint i64 %772, 4611686293338849280, !dbg !90 + %774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 0, !dbg !90 + %775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 1, !dbg !90 + %776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 2, !dbg !90 + %777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 3, !dbg !90 + %778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 4, !dbg !90 + %779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 5, !dbg !90 + %780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 6, !dbg !90 + %781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 7, !dbg !90 + %782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 8, !dbg !90 + %783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 9, !dbg !90 + %784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 10, !dbg !90 + %785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 11, !dbg !90 + %786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %762, 12, !dbg !90 + %787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 13, !dbg !90 + %788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 14, !dbg !90 + %789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 15, !dbg !90 + %790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 16, !dbg !90 + %791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 17, !dbg !90 + %792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 18, !dbg !90 + %793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 19, !dbg !90 + %794 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 20, !dbg !90 + %795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 21, !dbg !90 + %796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 22, !dbg !90 + %797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 23, !dbg !90 + %798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 24, !dbg !90 + %799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 25, !dbg !90 + %800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 26, !dbg !90 + %801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %762, 27, !dbg !90 + %802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 28, !dbg !90 + %803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 29, !dbg !90 + %804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 30, !dbg !90 + %805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %762, 31, !dbg !90 + %806 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %774, float %775, float %776, float %777, float %778, float %779, float %780, float %781, float %782, float %783, float %784, float %785, float %786, float %787, float %788, float %789, float %790, 
float %791, float %792, float %793, float %794, float %795, float %796, float %797, float %798, float %799, float %800, float %801, float %802, float %803, float %804, float %805, i64 %768, i64 %773, i1 true) #3, !dbg !90 + %807 = or disjoint i32 %531, 16480, !dbg !90 + %808 = add i32 %807, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !90 + %809 = lshr exact i32 %808, 4, !dbg !90 + %810 = and i32 %809, 16383, !dbg !90 + %811 = zext nneg i32 %810 to i64, !dbg !90 + %812 = or disjoint i64 %811, 4611686293372403712, !dbg !90 + %813 = add i32 %537, 8288, !dbg !90 + %814 = lshr exact i32 %813, 4, !dbg !90 + %815 = and i32 %814, 16383, !dbg !90 + %816 = zext nneg i32 %815 to i64, !dbg !90 + %817 = or disjoint i64 %816, 4611686293338849280, !dbg !90 + %818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 0, !dbg !90 + %819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 1, !dbg !90 + %820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 2, !dbg !90 + %821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 3, !dbg !90 + %822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 4, !dbg !90 + %823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 5, !dbg !90 + %824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 6, !dbg !90 + %825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 7, !dbg !90 + %826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 8, !dbg !90 + %827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 9, !dbg !90 + %828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 10, !dbg !90 + %829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%806, 11, !dbg !90 + %830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 12, !dbg !90 + %831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 13, !dbg !90 + %832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 14, !dbg !90 + %833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 15, !dbg !90 + %834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 16, !dbg !90 + %835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 17, !dbg !90 + %836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 18, !dbg !90 + %837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 19, !dbg !90 + %838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 20, !dbg !90 + %839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 21, !dbg !90 + %840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 22, !dbg !90 + %841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 23, !dbg !90 + %842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 24, !dbg !90 + %843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 25, !dbg !90 + %844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %806, 26, !dbg !90 + %845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 27, !dbg !90 + %846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 28, !dbg !90 + %847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 29, !dbg !90 + %848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 30, !dbg !90 + %849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %806, 31, !dbg !90 + %850 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %818, float %819, float %820, float %821, float %822, float %823, float %824, float %825, float %826, float %827, float %828, float %829, float %830, float %831, float %832, float %833, float %834, float %835, float %836, float %837, float %838, float %839, float %840, float %841, float %842, float %843, float %844, float %845, float %846, float %847, float %848, float %849, i64 %812, i64 %817, i1 true) #3, !dbg !90 + %851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 0, !dbg !90 + %852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 1, !dbg !90 + %853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 2, !dbg !90 + %854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 3, !dbg !90 + %855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 4, !dbg !90 + %856 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 5, !dbg !90 + %857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 6, !dbg !90 + %858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 7, !dbg !90 + %859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 8, !dbg !90 + %860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 9, !dbg !90 + %861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 10, !dbg !90 + %862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 11, !dbg !90 + %863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %850, 12, !dbg !90 + %864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 13, !dbg !90 + %865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 14, !dbg !90 + %866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 15, !dbg !90 + %867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 16, !dbg !90 + %868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 17, !dbg !90 + %869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 18, !dbg !90 + %870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 19, !dbg !90 + %871 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 20, !dbg !90 + %872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 21, !dbg !90 + %873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 22, !dbg !90 + %874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 23, !dbg !90 + %875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 24, !dbg !90 + %876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 25, !dbg !90 + %877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 26, !dbg !90 + %878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 27, !dbg !90 + %879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 28, !dbg !90 + %880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 29, !dbg !90 + %881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 30, !dbg !90 + %882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %850, 31, !dbg !90 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !90 + %883 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %851, float %852, float %853, float %854, float %855, float %856, float %857, float %858, float %859, float %860, float %861, float %862, float %863, float %864, float %865, float %866, float %867, float %868, float %869, float %870, float %871, float %872, float %873, float %874, float %875, float %876, float %877, float %878, float %879, float %880, float %881, float %882, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %528, i32 0, i32 0) #3, !dbg !90 + %884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 0, !dbg !90 + %885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 1, !dbg !90 + %886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 2, !dbg !90 + %887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr 
addrspace(3), i32, i32 } %883, 3, !dbg !90 + %888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 4, !dbg !90 + %889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 5, !dbg !90 + %890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 6, !dbg !90 + %891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 7, !dbg !90 + %892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 8, !dbg !90 + %893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 9, !dbg !90 + %894 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 10, !dbg !90 + %895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 11, !dbg !90 + %896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 12, !dbg !90 + %897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 13, !dbg !90 + %898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 14, !dbg !90 + %899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 15, !dbg !90 + %900 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 16, !dbg !90 + %901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 17, !dbg !90 + %902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 18, !dbg !90 + %903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 19, !dbg !90 + %904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 20, !dbg !90 + %905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 21, !dbg !90 + %906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 22, !dbg !90 + %907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 23, !dbg !90 + %908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 24, !dbg !90 + %909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 25, !dbg !90 + %910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 26, !dbg !90 + %911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 27, !dbg !90 + %912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 28, !dbg !90 + %913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 29, !dbg !90 + %914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 30, !dbg !90 + %915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %883, 31, !dbg !90 + %916 = fmul float %884, 0x3FB6A09E60000000, !dbg !93 + %917 = fmul float %885, 0x3FB6A09E60000000, !dbg !93 + %918 = fmul float %886, 0x3FB6A09E60000000, !dbg !93 + %919 = fmul float %887, 0x3FB6A09E60000000, !dbg !93 + %920 = fmul float %888, 0x3FB6A09E60000000, !dbg !93 + %921 = fmul float %889, 0x3FB6A09E60000000, !dbg !93 + %922 = fmul float %890, 0x3FB6A09E60000000, !dbg !93 + %923 = fmul float %891, 0x3FB6A09E60000000, !dbg !93 + %924 = fmul float %892, 0x3FB6A09E60000000, !dbg !93 + %925 = fmul float %893, 0x3FB6A09E60000000, !dbg !93 + %926 = fmul float %894, 0x3FB6A09E60000000, !dbg !93 + %927 = fmul float %895, 0x3FB6A09E60000000, !dbg !93 + %928 = fmul float %896, 0x3FB6A09E60000000, !dbg !93 + %929 = fmul float %897, 0x3FB6A09E60000000, !dbg !93 + %930 = fmul float %898, 0x3FB6A09E60000000, !dbg !93 + %931 = fmul float %899, 0x3FB6A09E60000000, 
!dbg !93 + %932 = fmul float %900, 0x3FB6A09E60000000, !dbg !93 + %933 = fmul float %901, 0x3FB6A09E60000000, !dbg !93 + %934 = fmul float %902, 0x3FB6A09E60000000, !dbg !93 + %935 = fmul float %903, 0x3FB6A09E60000000, !dbg !93 + %936 = fmul float %904, 0x3FB6A09E60000000, !dbg !93 + %937 = fmul float %905, 0x3FB6A09E60000000, !dbg !93 + %938 = fmul float %906, 0x3FB6A09E60000000, !dbg !93 + %939 = fmul float %907, 0x3FB6A09E60000000, !dbg !93 + %940 = fmul float %908, 0x3FB6A09E60000000, !dbg !93 + %941 = fmul float %909, 0x3FB6A09E60000000, !dbg !93 + %942 = fmul float %910, 0x3FB6A09E60000000, !dbg !93 + %943 = fmul float %911, 0x3FB6A09E60000000, !dbg !93 + %944 = fmul float %912, 0x3FB6A09E60000000, !dbg !93 + %945 = fmul float %913, 0x3FB6A09E60000000, !dbg !93 + %946 = fmul float %914, 0x3FB6A09E60000000, !dbg !93 + %947 = fmul float %915, 0x3FB6A09E60000000, !dbg !93 + %948 = extractelement <16 x i32> %520, i64 15, !dbg !94 + %949 = icmp sge i32 %338, %948, !dbg !94 + %950 = extractelement <16 x i32> %520, i64 14, !dbg !95 + %951 = icmp sge i32 %338, %950, !dbg !94 + %952 = icmp sge i32 %341, %948, !dbg !94 + %953 = icmp sge i32 %341, %950, !dbg !94 + %954 = extractelement <16 x i32> %520, i64 13, !dbg !94 + %955 = icmp sge i32 %338, %954, !dbg !94 + %956 = extractelement <16 x i32> %520, i64 12, !dbg !95 + %957 = icmp sge i32 %338, %956, !dbg !94 + %958 = icmp sge i32 %341, %954, !dbg !94 + %959 = icmp sge i32 %341, %956, !dbg !94 + %960 = extractelement <16 x i32> %520, i64 11, !dbg !94 + %961 = icmp sge i32 %338, %960, !dbg !94 + %962 = extractelement <16 x i32> %520, i64 10, !dbg !95 + %963 = icmp sge i32 %338, %962, !dbg !94 + %964 = icmp sge i32 %341, %960, !dbg !94 + %965 = icmp sge i32 %341, %962, !dbg !94 + %966 = extractelement <16 x i32> %520, i64 9, !dbg !94 + %967 = icmp sge i32 %338, %966, !dbg !94 + %968 = extractelement <16 x i32> %520, i64 8, !dbg !95 + %969 = icmp sge i32 %338, %968, !dbg !94 + %970 = icmp sge i32 %341, %966, !dbg !94 + 
%971 = icmp sge i32 %341, %968, !dbg !94 + %972 = extractelement <16 x i32> %520, i64 7, !dbg !94 + %973 = icmp sge i32 %338, %972, !dbg !94 + %974 = extractelement <16 x i32> %520, i64 6, !dbg !95 + %975 = icmp sge i32 %338, %974, !dbg !94 + %976 = icmp sge i32 %341, %972, !dbg !94 + %977 = icmp sge i32 %341, %974, !dbg !94 + %978 = extractelement <16 x i32> %520, i64 5, !dbg !94 + %979 = icmp sge i32 %338, %978, !dbg !94 + %980 = extractelement <16 x i32> %520, i64 4, !dbg !95 + %981 = icmp sge i32 %338, %980, !dbg !94 + %982 = icmp sge i32 %341, %978, !dbg !94 + %983 = icmp sge i32 %341, %980, !dbg !94 + %984 = extractelement <16 x i32> %520, i64 3, !dbg !94 + %985 = icmp sge i32 %338, %984, !dbg !94 + %986 = extractelement <16 x i32> %520, i64 2, !dbg !95 + %987 = icmp sge i32 %338, %986, !dbg !94 + %988 = icmp sge i32 %341, %984, !dbg !94 + %989 = icmp sge i32 %341, %986, !dbg !94 + %990 = extractelement <16 x i32> %520, i64 1, !dbg !94 + %991 = icmp sge i32 %338, %990, !dbg !94 + %992 = extractelement <16 x i32> %520, i64 0, !dbg !95 + %993 = icmp sge i32 %338, %992, !dbg !94 + %994 = icmp sge i32 %341, %990, !dbg !94 + %995 = icmp sge i32 %341, %992, !dbg !94 + %996 = and i1 %399, %949, !dbg !96 + %997 = and i1 %399, %951, !dbg !96 + %998 = and i1 %400, %952, !dbg !96 + %999 = and i1 %400, %953, !dbg !96 + %1000 = and i1 %399, %955, !dbg !96 + %1001 = and i1 %399, %957, !dbg !96 + %1002 = and i1 %400, %958, !dbg !96 + %1003 = and i1 %400, %959, !dbg !96 + %1004 = and i1 %399, %961, !dbg !96 + %1005 = and i1 %399, %963, !dbg !96 + %1006 = and i1 %400, %964, !dbg !96 + %1007 = and i1 %400, %965, !dbg !96 + %1008 = and i1 %399, %967, !dbg !96 + %1009 = and i1 %399, %969, !dbg !96 + %1010 = and i1 %400, %970, !dbg !96 + %1011 = and i1 %400, %971, !dbg !96 + %1012 = and i1 %399, %973, !dbg !96 + %1013 = and i1 %399, %975, !dbg !96 + %1014 = and i1 %400, %976, !dbg !96 + %1015 = and i1 %400, %977, !dbg !96 + %1016 = and i1 %399, %979, !dbg !96 + %1017 = and i1 
%399, %981, !dbg !96 + %1018 = and i1 %400, %982, !dbg !96 + %1019 = and i1 %400, %983, !dbg !96 + %1020 = and i1 %399, %985, !dbg !96 + %1021 = and i1 %399, %987, !dbg !96 + %1022 = and i1 %400, %988, !dbg !96 + %1023 = and i1 %400, %989, !dbg !96 + %1024 = and i1 %399, %991, !dbg !96 + %1025 = and i1 %399, %993, !dbg !96 + %1026 = and i1 %400, %994, !dbg !96 + %1027 = and i1 %400, %995, !dbg !96 + %1028 = icmp sgt i32 %948, 2047, !dbg !97 + %1029 = icmp sgt i32 %954, 2047, !dbg !97 + %1030 = icmp sgt i32 %960, 2047, !dbg !97 + %1031 = icmp sgt i32 %966, 2047, !dbg !97 + %1032 = icmp sgt i32 %972, 2047, !dbg !97 + %1033 = icmp sgt i32 %978, 2047, !dbg !97 + %1034 = icmp sgt i32 %984, 2047, !dbg !97 + %1035 = icmp sgt i32 %990, 2047, !dbg !97 + %1036 = shufflevector <16 x i32> %520, <16 x i32> poison, <8 x i32> , !dbg !95 + %1037 = srem <8 x i32> %1036, splat (i32 2048), !dbg !95 + %1038 = icmp ne <8 x i32> %1037, zeroinitializer, !dbg !98 + %1039 = shufflevector <16 x i32> %520, <16 x i32> poison, <8 x i32> , !dbg !99 + %1040 = and <8 x i32> %1039, splat (i32 -2147481601), !dbg !99 + %1041 = icmp ugt <8 x i32> %1040, splat (i32 -2147483648), !dbg !99 + %1042 = trunc <8 x i32> %1039 to <8 x i16>, !dbg !100 + %1043 = and <8 x i16> %1042, splat (i16 2047), !dbg !100 + %1044 = zext nneg <8 x i16> %1043 to <8 x i64>, !dbg !92 + %1045 = icmp sgt <8 x i64> %450, %1044, !dbg !92 + %1046 = and <8 x i1> %1041, %1038, !dbg !101 + %1047 = add nsw <8 x i32> %1037, splat (i32 2048), !dbg !102 + %1048 = select <8 x i1> %1046, <8 x i32> %1047, <8 x i32> %1037, !dbg !103 + %1049 = sext <8 x i32> %1048 to <8 x i64>, !dbg !92 + %1050 = icmp sgt <8 x i64> %450, %1049, !dbg !92 + %1051 = extractelement <8 x i1> %1045, i64 7, !dbg !104 + %1052 = and i1 %1028, %1051, !dbg !104 + %1053 = extractelement <8 x i1> %1050, i64 7, !dbg !104 + %1054 = and i1 %1028, %1053, !dbg !104 + %1055 = extractelement <8 x i1> %1045, i64 6, !dbg !104 + %1056 = and i1 %1029, %1055, !dbg !104 + %1057 = 
extractelement <8 x i1> %1050, i64 6, !dbg !104 + %1058 = and i1 %1029, %1057, !dbg !104 + %1059 = extractelement <8 x i1> %1045, i64 5, !dbg !104 + %1060 = and i1 %1030, %1059, !dbg !104 + %1061 = extractelement <8 x i1> %1050, i64 5, !dbg !104 + %1062 = and i1 %1030, %1061, !dbg !104 + %1063 = extractelement <8 x i1> %1045, i64 4, !dbg !104 + %1064 = and i1 %1031, %1063, !dbg !104 + %1065 = extractelement <8 x i1> %1050, i64 4, !dbg !104 + %1066 = and i1 %1031, %1065, !dbg !104 + %1067 = extractelement <8 x i1> %1045, i64 3, !dbg !104 + %1068 = and i1 %1032, %1067, !dbg !104 + %1069 = extractelement <8 x i1> %1050, i64 3, !dbg !104 + %1070 = and i1 %1032, %1069, !dbg !104 + %1071 = extractelement <8 x i1> %1045, i64 2, !dbg !104 + %1072 = and i1 %1033, %1071, !dbg !104 + %1073 = extractelement <8 x i1> %1050, i64 2, !dbg !104 + %1074 = and i1 %1033, %1073, !dbg !104 + %1075 = extractelement <8 x i1> %1045, i64 1, !dbg !104 + %1076 = and i1 %1034, %1075, !dbg !104 + %1077 = extractelement <8 x i1> %1050, i64 1, !dbg !104 + %1078 = and i1 %1034, %1077, !dbg !104 + %1079 = extractelement <8 x i1> %1045, i64 0, !dbg !104 + %1080 = and i1 %1035, %1079, !dbg !104 + %1081 = extractelement <8 x i1> %1050, i64 0, !dbg !104 + %1082 = and i1 %1035, %1081, !dbg !104 + %1083 = sub <32 x i32> %521, %94, !dbg !105 + %1084 = and <32 x i32> %1083, splat (i32 2047), !dbg !106 + %1085 = icmp eq <32 x i32> %1084, zeroinitializer, !dbg !107 + %1086 = extractelement <32 x i1> %1085, i64 31, !dbg !108 + %1087 = and i1 %1086, %1052, !dbg !108 + %1088 = extractelement <32 x i1> %1085, i64 30, !dbg !108 + %1089 = and i1 %1088, %1054, !dbg !108 + %1090 = extractelement <32 x i1> %1085, i64 29, !dbg !108 + %1091 = and i1 %1090, %1052, !dbg !108 + %1092 = extractelement <32 x i1> %1085, i64 28, !dbg !108 + %1093 = and i1 %1092, %1054, !dbg !108 + %1094 = extractelement <32 x i1> %1085, i64 27, !dbg !108 + %1095 = and i1 %1094, %1056, !dbg !108 + %1096 = extractelement <32 x i1> %1085, i64 
26, !dbg !108 + %1097 = and i1 %1096, %1058, !dbg !108 + %1098 = extractelement <32 x i1> %1085, i64 25, !dbg !108 + %1099 = and i1 %1098, %1056, !dbg !108 + %1100 = extractelement <32 x i1> %1085, i64 24, !dbg !108 + %1101 = and i1 %1100, %1058, !dbg !108 + %1102 = extractelement <32 x i1> %1085, i64 23, !dbg !108 + %1103 = and i1 %1102, %1060, !dbg !108 + %1104 = extractelement <32 x i1> %1085, i64 22, !dbg !108 + %1105 = and i1 %1104, %1062, !dbg !108 + %1106 = extractelement <32 x i1> %1085, i64 21, !dbg !108 + %1107 = and i1 %1106, %1060, !dbg !108 + %1108 = extractelement <32 x i1> %1085, i64 20, !dbg !108 + %1109 = and i1 %1108, %1062, !dbg !108 + %1110 = extractelement <32 x i1> %1085, i64 19, !dbg !108 + %1111 = and i1 %1110, %1064, !dbg !108 + %1112 = extractelement <32 x i1> %1085, i64 18, !dbg !108 + %1113 = and i1 %1112, %1066, !dbg !108 + %1114 = extractelement <32 x i1> %1085, i64 17, !dbg !108 + %1115 = and i1 %1114, %1064, !dbg !108 + %1116 = extractelement <32 x i1> %1085, i64 16, !dbg !108 + %1117 = and i1 %1116, %1066, !dbg !108 + %1118 = extractelement <32 x i1> %1085, i64 15, !dbg !108 + %1119 = and i1 %1118, %1068, !dbg !108 + %1120 = extractelement <32 x i1> %1085, i64 14, !dbg !108 + %1121 = and i1 %1120, %1070, !dbg !108 + %1122 = extractelement <32 x i1> %1085, i64 13, !dbg !108 + %1123 = and i1 %1122, %1068, !dbg !108 + %1124 = extractelement <32 x i1> %1085, i64 12, !dbg !108 + %1125 = and i1 %1124, %1070, !dbg !108 + %1126 = extractelement <32 x i1> %1085, i64 11, !dbg !108 + %1127 = and i1 %1126, %1072, !dbg !108 + %1128 = extractelement <32 x i1> %1085, i64 10, !dbg !108 + %1129 = and i1 %1128, %1074, !dbg !108 + %1130 = extractelement <32 x i1> %1085, i64 9, !dbg !108 + %1131 = and i1 %1130, %1072, !dbg !108 + %1132 = extractelement <32 x i1> %1085, i64 8, !dbg !108 + %1133 = and i1 %1132, %1074, !dbg !108 + %1134 = extractelement <32 x i1> %1085, i64 7, !dbg !108 + %1135 = and i1 %1134, %1076, !dbg !108 + %1136 = extractelement <32 
x i1> %1085, i64 6, !dbg !108 + %1137 = and i1 %1136, %1078, !dbg !108 + %1138 = extractelement <32 x i1> %1085, i64 5, !dbg !108 + %1139 = and i1 %1138, %1076, !dbg !108 + %1140 = extractelement <32 x i1> %1085, i64 4, !dbg !108 + %1141 = and i1 %1140, %1078, !dbg !108 + %1142 = extractelement <32 x i1> %1085, i64 3, !dbg !108 + %1143 = and i1 %1142, %1080, !dbg !108 + %1144 = extractelement <32 x i1> %1085, i64 2, !dbg !108 + %1145 = and i1 %1144, %1082, !dbg !108 + %1146 = extractelement <32 x i1> %1085, i64 1, !dbg !108 + %1147 = and i1 %1146, %1080, !dbg !108 + %1148 = extractelement <32 x i1> %1085, i64 0, !dbg !108 + %1149 = and i1 %1148, %1082, !dbg !108 + %1150 = or i1 %996, %1087, !dbg !109 + %1151 = or i1 %997, %1089, !dbg !109 + %1152 = or i1 %998, %1091, !dbg !109 + %1153 = or i1 %999, %1093, !dbg !109 + %1154 = or i1 %1000, %1095, !dbg !109 + %1155 = or i1 %1001, %1097, !dbg !109 + %1156 = or i1 %1002, %1099, !dbg !109 + %1157 = or i1 %1003, %1101, !dbg !109 + %1158 = or i1 %1004, %1103, !dbg !109 + %1159 = or i1 %1005, %1105, !dbg !109 + %1160 = or i1 %1006, %1107, !dbg !109 + %1161 = or i1 %1007, %1109, !dbg !109 + %1162 = or i1 %1008, %1111, !dbg !109 + %1163 = or i1 %1009, %1113, !dbg !109 + %1164 = or i1 %1010, %1115, !dbg !109 + %1165 = or i1 %1011, %1117, !dbg !109 + %1166 = or i1 %1012, %1119, !dbg !109 + %1167 = or i1 %1013, %1121, !dbg !109 + %1168 = or i1 %1014, %1123, !dbg !109 + %1169 = or i1 %1015, %1125, !dbg !109 + %1170 = or i1 %1016, %1127, !dbg !109 + %1171 = or i1 %1017, %1129, !dbg !109 + %1172 = or i1 %1018, %1131, !dbg !109 + %1173 = or i1 %1019, %1133, !dbg !109 + %1174 = or i1 %1020, %1135, !dbg !109 + %1175 = or i1 %1021, %1137, !dbg !109 + %1176 = or i1 %1022, %1139, !dbg !109 + %1177 = or i1 %1023, %1141, !dbg !109 + %1178 = or i1 %1024, %1143, !dbg !109 + %1179 = or i1 %1025, %1145, !dbg !109 + %1180 = or i1 %1026, %1147, !dbg !109 + %1181 = or i1 %1027, %1149, !dbg !109 + %1182 = fmul float %916, 0x3FF7154760000000, !dbg 
!110 + %1183 = select i1 %1150, float %1182, float 0xFFF0000000000000, !dbg !111 + %1184 = fmul float %917, 0x3FF7154760000000, !dbg !110 + %1185 = select i1 %1151, float %1184, float 0xFFF0000000000000, !dbg !111 + %1186 = fmul float %918, 0x3FF7154760000000, !dbg !110 + %1187 = select i1 %1152, float %1186, float 0xFFF0000000000000, !dbg !111 + %1188 = fmul float %919, 0x3FF7154760000000, !dbg !110 + %1189 = select i1 %1153, float %1188, float 0xFFF0000000000000, !dbg !111 + %1190 = fmul float %920, 0x3FF7154760000000, !dbg !110 + %1191 = select i1 %1154, float %1190, float 0xFFF0000000000000, !dbg !111 + %1192 = fmul float %921, 0x3FF7154760000000, !dbg !110 + %1193 = select i1 %1155, float %1192, float 0xFFF0000000000000, !dbg !111 + %1194 = fmul float %922, 0x3FF7154760000000, !dbg !110 + %1195 = select i1 %1156, float %1194, float 0xFFF0000000000000, !dbg !111 + %1196 = fmul float %923, 0x3FF7154760000000, !dbg !110 + %1197 = select i1 %1157, float %1196, float 0xFFF0000000000000, !dbg !111 + %1198 = fmul float %924, 0x3FF7154760000000, !dbg !110 + %1199 = select i1 %1158, float %1198, float 0xFFF0000000000000, !dbg !111 + %1200 = fmul float %925, 0x3FF7154760000000, !dbg !110 + %1201 = select i1 %1159, float %1200, float 0xFFF0000000000000, !dbg !111 + %1202 = fmul float %926, 0x3FF7154760000000, !dbg !110 + %1203 = select i1 %1160, float %1202, float 0xFFF0000000000000, !dbg !111 + %1204 = fmul float %927, 0x3FF7154760000000, !dbg !110 + %1205 = select i1 %1161, float %1204, float 0xFFF0000000000000, !dbg !111 + %1206 = fmul float %928, 0x3FF7154760000000, !dbg !110 + %1207 = select i1 %1162, float %1206, float 0xFFF0000000000000, !dbg !111 + %1208 = fmul float %929, 0x3FF7154760000000, !dbg !110 + %1209 = select i1 %1163, float %1208, float 0xFFF0000000000000, !dbg !111 + %1210 = fmul float %930, 0x3FF7154760000000, !dbg !110 + %1211 = select i1 %1164, float %1210, float 0xFFF0000000000000, !dbg !111 + %1212 = fmul float %931, 0x3FF7154760000000, !dbg !110 
+ %1213 = select i1 %1165, float %1212, float 0xFFF0000000000000, !dbg !111 + %1214 = fmul float %932, 0x3FF7154760000000, !dbg !110 + %1215 = select i1 %1166, float %1214, float 0xFFF0000000000000, !dbg !111 + %1216 = fmul float %933, 0x3FF7154760000000, !dbg !110 + %1217 = select i1 %1167, float %1216, float 0xFFF0000000000000, !dbg !111 + %1218 = fmul float %934, 0x3FF7154760000000, !dbg !110 + %1219 = select i1 %1168, float %1218, float 0xFFF0000000000000, !dbg !111 + %1220 = fmul float %935, 0x3FF7154760000000, !dbg !110 + %1221 = select i1 %1169, float %1220, float 0xFFF0000000000000, !dbg !111 + %1222 = fmul float %936, 0x3FF7154760000000, !dbg !110 + %1223 = select i1 %1170, float %1222, float 0xFFF0000000000000, !dbg !111 + %1224 = fmul float %937, 0x3FF7154760000000, !dbg !110 + %1225 = select i1 %1171, float %1224, float 0xFFF0000000000000, !dbg !111 + %1226 = fmul float %938, 0x3FF7154760000000, !dbg !110 + %1227 = select i1 %1172, float %1226, float 0xFFF0000000000000, !dbg !111 + %1228 = fmul float %939, 0x3FF7154760000000, !dbg !110 + %1229 = select i1 %1173, float %1228, float 0xFFF0000000000000, !dbg !111 + %1230 = fmul float %940, 0x3FF7154760000000, !dbg !110 + %1231 = select i1 %1174, float %1230, float 0xFFF0000000000000, !dbg !111 + %1232 = fmul float %941, 0x3FF7154760000000, !dbg !110 + %1233 = select i1 %1175, float %1232, float 0xFFF0000000000000, !dbg !111 + %1234 = fmul float %942, 0x3FF7154760000000, !dbg !110 + %1235 = select i1 %1176, float %1234, float 0xFFF0000000000000, !dbg !111 + %1236 = fmul float %943, 0x3FF7154760000000, !dbg !110 + %1237 = select i1 %1177, float %1236, float 0xFFF0000000000000, !dbg !111 + %1238 = fmul float %944, 0x3FF7154760000000, !dbg !110 + %1239 = select i1 %1178, float %1238, float 0xFFF0000000000000, !dbg !111 + %1240 = fmul float %945, 0x3FF7154760000000, !dbg !110 + %1241 = select i1 %1179, float %1240, float 0xFFF0000000000000, !dbg !111 + %1242 = fmul float %946, 0x3FF7154760000000, !dbg !110 + 
%1243 = select i1 %1180, float %1242, float 0xFFF0000000000000, !dbg !111 + %1244 = fmul float %947, 0x3FF7154760000000, !dbg !110 + %1245 = select i1 %1181, float %1244, float 0xFFF0000000000000, !dbg !111 + %1246 = fsub float %1183, %356, !dbg !112 + %1247 = fsub float %1185, %356, !dbg !112 + %1248 = fsub float %1187, %357, !dbg !112 + %1249 = fsub float %1189, %357, !dbg !112 + %1250 = fsub float %1191, %356, !dbg !112 + %1251 = fsub float %1193, %356, !dbg !112 + %1252 = fsub float %1195, %357, !dbg !112 + %1253 = fsub float %1197, %357, !dbg !112 + %1254 = fsub float %1199, %356, !dbg !112 + %1255 = fsub float %1201, %356, !dbg !112 + %1256 = fsub float %1203, %357, !dbg !112 + %1257 = fsub float %1205, %357, !dbg !112 + %1258 = fsub float %1207, %356, !dbg !112 + %1259 = fsub float %1209, %356, !dbg !112 + %1260 = fsub float %1211, %357, !dbg !112 + %1261 = fsub float %1213, %357, !dbg !112 + %1262 = fsub float %1215, %356, !dbg !112 + %1263 = fsub float %1217, %356, !dbg !112 + %1264 = fsub float %1219, %357, !dbg !112 + %1265 = fsub float %1221, %357, !dbg !112 + %1266 = fsub float %1223, %356, !dbg !112 + %1267 = fsub float %1225, %356, !dbg !112 + %1268 = fsub float %1227, %357, !dbg !112 + %1269 = fsub float %1229, %357, !dbg !112 + %1270 = fsub float %1231, %356, !dbg !112 + %1271 = fsub float %1233, %356, !dbg !112 + %1272 = fsub float %1235, %357, !dbg !112 + %1273 = fsub float %1237, %357, !dbg !112 + %1274 = fsub float %1239, %356, !dbg !112 + %1275 = fsub float %1241, %356, !dbg !112 + %1276 = fsub float %1243, %357, !dbg !112 + %1277 = fsub float %1245, %357, !dbg !112 + %1278 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1412 = icmp eq i32 %1278, 0, !dbg !113 + br i1 %.not.i1412, label %1281, label %1279, !dbg !113 + +1279: ; preds = %451 + %1280 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1246) #3, !dbg !113 + br label %__nv_exp2f.exit1414, !dbg !113 + +1281: ; preds = %451 + %1282 = tail call float 
@llvm.nvvm.ex2.approx.f(float %1246) #3, !dbg !113 + br label %__nv_exp2f.exit1414, !dbg !113 + +__nv_exp2f.exit1414: ; preds = %1279, %1281 + %.0.i1413 = phi float [ %1280, %1279 ], [ %1282, %1281 ], !dbg !113 + %1283 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1415 = icmp eq i32 %1283, 0, !dbg !113 + br i1 %.not.i1415, label %1286, label %1284, !dbg !113 + +1284: ; preds = %__nv_exp2f.exit1414 + %1285 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1247) #3, !dbg !113 + br label %__nv_exp2f.exit1417, !dbg !113 + +1286: ; preds = %__nv_exp2f.exit1414 + %1287 = tail call float @llvm.nvvm.ex2.approx.f(float %1247) #3, !dbg !113 + br label %__nv_exp2f.exit1417, !dbg !113 + +__nv_exp2f.exit1417: ; preds = %1284, %1286 + %.0.i1416 = phi float [ %1285, %1284 ], [ %1287, %1286 ], !dbg !113 + %1288 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1418 = icmp eq i32 %1288, 0, !dbg !113 + br i1 %.not.i1418, label %1291, label %1289, !dbg !113 + +1289: ; preds = %__nv_exp2f.exit1417 + %1290 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1248) #3, !dbg !113 + br label %__nv_exp2f.exit1420, !dbg !113 + +1291: ; preds = %__nv_exp2f.exit1417 + %1292 = tail call float @llvm.nvvm.ex2.approx.f(float %1248) #3, !dbg !113 + br label %__nv_exp2f.exit1420, !dbg !113 + +__nv_exp2f.exit1420: ; preds = %1289, %1291 + %.0.i1419 = phi float [ %1290, %1289 ], [ %1292, %1291 ], !dbg !113 + %1293 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1421 = icmp eq i32 %1293, 0, !dbg !113 + br i1 %.not.i1421, label %1296, label %1294, !dbg !113 + +1294: ; preds = %__nv_exp2f.exit1420 + %1295 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1249) #3, !dbg !113 + br label %__nv_exp2f.exit1423, !dbg !113 + +1296: ; preds = %__nv_exp2f.exit1420 + %1297 = tail call float @llvm.nvvm.ex2.approx.f(float %1249) #3, !dbg !113 + br label %__nv_exp2f.exit1423, !dbg !113 + +__nv_exp2f.exit1423: ; preds = %1294, %1296 + 
%.0.i1422 = phi float [ %1295, %1294 ], [ %1297, %1296 ], !dbg !113 + %1298 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1424 = icmp eq i32 %1298, 0, !dbg !113 + br i1 %.not.i1424, label %1301, label %1299, !dbg !113 + +1299: ; preds = %__nv_exp2f.exit1423 + %1300 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1250) #3, !dbg !113 + br label %__nv_exp2f.exit1426, !dbg !113 + +1301: ; preds = %__nv_exp2f.exit1423 + %1302 = tail call float @llvm.nvvm.ex2.approx.f(float %1250) #3, !dbg !113 + br label %__nv_exp2f.exit1426, !dbg !113 + +__nv_exp2f.exit1426: ; preds = %1299, %1301 + %.0.i1425 = phi float [ %1300, %1299 ], [ %1302, %1301 ], !dbg !113 + %1303 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1427 = icmp eq i32 %1303, 0, !dbg !113 + br i1 %.not.i1427, label %1306, label %1304, !dbg !113 + +1304: ; preds = %__nv_exp2f.exit1426 + %1305 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1251) #3, !dbg !113 + br label %__nv_exp2f.exit1429, !dbg !113 + +1306: ; preds = %__nv_exp2f.exit1426 + %1307 = tail call float @llvm.nvvm.ex2.approx.f(float %1251) #3, !dbg !113 + br label %__nv_exp2f.exit1429, !dbg !113 + +__nv_exp2f.exit1429: ; preds = %1304, %1306 + %.0.i1428 = phi float [ %1305, %1304 ], [ %1307, %1306 ], !dbg !113 + %1308 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1430 = icmp eq i32 %1308, 0, !dbg !113 + br i1 %.not.i1430, label %1311, label %1309, !dbg !113 + +1309: ; preds = %__nv_exp2f.exit1429 + %1310 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1252) #3, !dbg !113 + br label %__nv_exp2f.exit1432, !dbg !113 + +1311: ; preds = %__nv_exp2f.exit1429 + %1312 = tail call float @llvm.nvvm.ex2.approx.f(float %1252) #3, !dbg !113 + br label %__nv_exp2f.exit1432, !dbg !113 + +__nv_exp2f.exit1432: ; preds = %1309, %1311 + %.0.i1431 = phi float [ %1310, %1309 ], [ %1312, %1311 ], !dbg !113 + %1313 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + 
%.not.i1433 = icmp eq i32 %1313, 0, !dbg !113 + br i1 %.not.i1433, label %1316, label %1314, !dbg !113 + +1314: ; preds = %__nv_exp2f.exit1432 + %1315 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1253) #3, !dbg !113 + br label %__nv_exp2f.exit1435, !dbg !113 + +1316: ; preds = %__nv_exp2f.exit1432 + %1317 = tail call float @llvm.nvvm.ex2.approx.f(float %1253) #3, !dbg !113 + br label %__nv_exp2f.exit1435, !dbg !113 + +__nv_exp2f.exit1435: ; preds = %1314, %1316 + %.0.i1434 = phi float [ %1315, %1314 ], [ %1317, %1316 ], !dbg !113 + %1318 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1436 = icmp eq i32 %1318, 0, !dbg !113 + br i1 %.not.i1436, label %1321, label %1319, !dbg !113 + +1319: ; preds = %__nv_exp2f.exit1435 + %1320 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1254) #3, !dbg !113 + br label %__nv_exp2f.exit1438, !dbg !113 + +1321: ; preds = %__nv_exp2f.exit1435 + %1322 = tail call float @llvm.nvvm.ex2.approx.f(float %1254) #3, !dbg !113 + br label %__nv_exp2f.exit1438, !dbg !113 + +__nv_exp2f.exit1438: ; preds = %1319, %1321 + %.0.i1437 = phi float [ %1320, %1319 ], [ %1322, %1321 ], !dbg !113 + %1323 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1439 = icmp eq i32 %1323, 0, !dbg !113 + br i1 %.not.i1439, label %1326, label %1324, !dbg !113 + +1324: ; preds = %__nv_exp2f.exit1438 + %1325 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1255) #3, !dbg !113 + br label %__nv_exp2f.exit1441, !dbg !113 + +1326: ; preds = %__nv_exp2f.exit1438 + %1327 = tail call float @llvm.nvvm.ex2.approx.f(float %1255) #3, !dbg !113 + br label %__nv_exp2f.exit1441, !dbg !113 + +__nv_exp2f.exit1441: ; preds = %1324, %1326 + %.0.i1440 = phi float [ %1325, %1324 ], [ %1327, %1326 ], !dbg !113 + %1328 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1442 = icmp eq i32 %1328, 0, !dbg !113 + br i1 %.not.i1442, label %1331, label %1329, !dbg !113 + +1329: ; preds = %__nv_exp2f.exit1441 
+ %1330 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1256) #3, !dbg !113 + br label %__nv_exp2f.exit1444, !dbg !113 + +1331: ; preds = %__nv_exp2f.exit1441 + %1332 = tail call float @llvm.nvvm.ex2.approx.f(float %1256) #3, !dbg !113 + br label %__nv_exp2f.exit1444, !dbg !113 + +__nv_exp2f.exit1444: ; preds = %1329, %1331 + %.0.i1443 = phi float [ %1330, %1329 ], [ %1332, %1331 ], !dbg !113 + %1333 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1445 = icmp eq i32 %1333, 0, !dbg !113 + br i1 %.not.i1445, label %1336, label %1334, !dbg !113 + +1334: ; preds = %__nv_exp2f.exit1444 + %1335 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1257) #3, !dbg !113 + br label %__nv_exp2f.exit1447, !dbg !113 + +1336: ; preds = %__nv_exp2f.exit1444 + %1337 = tail call float @llvm.nvvm.ex2.approx.f(float %1257) #3, !dbg !113 + br label %__nv_exp2f.exit1447, !dbg !113 + +__nv_exp2f.exit1447: ; preds = %1334, %1336 + %.0.i1446 = phi float [ %1335, %1334 ], [ %1337, %1336 ], !dbg !113 + %1338 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1448 = icmp eq i32 %1338, 0, !dbg !113 + br i1 %.not.i1448, label %1341, label %1339, !dbg !113 + +1339: ; preds = %__nv_exp2f.exit1447 + %1340 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1258) #3, !dbg !113 + br label %__nv_exp2f.exit1450, !dbg !113 + +1341: ; preds = %__nv_exp2f.exit1447 + %1342 = tail call float @llvm.nvvm.ex2.approx.f(float %1258) #3, !dbg !113 + br label %__nv_exp2f.exit1450, !dbg !113 + +__nv_exp2f.exit1450: ; preds = %1339, %1341 + %.0.i1449 = phi float [ %1340, %1339 ], [ %1342, %1341 ], !dbg !113 + %1343 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1451 = icmp eq i32 %1343, 0, !dbg !113 + br i1 %.not.i1451, label %1346, label %1344, !dbg !113 + +1344: ; preds = %__nv_exp2f.exit1450 + %1345 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1259) #3, !dbg !113 + br label %__nv_exp2f.exit1453, !dbg !113 + +1346: ; preds = 
%__nv_exp2f.exit1450 + %1347 = tail call float @llvm.nvvm.ex2.approx.f(float %1259) #3, !dbg !113 + br label %__nv_exp2f.exit1453, !dbg !113 + +__nv_exp2f.exit1453: ; preds = %1344, %1346 + %.0.i1452 = phi float [ %1345, %1344 ], [ %1347, %1346 ], !dbg !113 + %1348 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1454 = icmp eq i32 %1348, 0, !dbg !113 + br i1 %.not.i1454, label %1351, label %1349, !dbg !113 + +1349: ; preds = %__nv_exp2f.exit1453 + %1350 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1260) #3, !dbg !113 + br label %__nv_exp2f.exit1456, !dbg !113 + +1351: ; preds = %__nv_exp2f.exit1453 + %1352 = tail call float @llvm.nvvm.ex2.approx.f(float %1260) #3, !dbg !113 + br label %__nv_exp2f.exit1456, !dbg !113 + +__nv_exp2f.exit1456: ; preds = %1349, %1351 + %.0.i1455 = phi float [ %1350, %1349 ], [ %1352, %1351 ], !dbg !113 + %1353 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1457 = icmp eq i32 %1353, 0, !dbg !113 + br i1 %.not.i1457, label %1356, label %1354, !dbg !113 + +1354: ; preds = %__nv_exp2f.exit1456 + %1355 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1261) #3, !dbg !113 + br label %__nv_exp2f.exit1459, !dbg !113 + +1356: ; preds = %__nv_exp2f.exit1456 + %1357 = tail call float @llvm.nvvm.ex2.approx.f(float %1261) #3, !dbg !113 + br label %__nv_exp2f.exit1459, !dbg !113 + +__nv_exp2f.exit1459: ; preds = %1354, %1356 + %.0.i1458 = phi float [ %1355, %1354 ], [ %1357, %1356 ], !dbg !113 + %1358 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1460 = icmp eq i32 %1358, 0, !dbg !113 + br i1 %.not.i1460, label %1361, label %1359, !dbg !113 + +1359: ; preds = %__nv_exp2f.exit1459 + %1360 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1262) #3, !dbg !113 + br label %__nv_exp2f.exit1462, !dbg !113 + +1361: ; preds = %__nv_exp2f.exit1459 + %1362 = tail call float @llvm.nvvm.ex2.approx.f(float %1262) #3, !dbg !113 + br label %__nv_exp2f.exit1462, !dbg !113 + 
+__nv_exp2f.exit1462: ; preds = %1359, %1361 + %.0.i1461 = phi float [ %1360, %1359 ], [ %1362, %1361 ], !dbg !113 + %1363 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1463 = icmp eq i32 %1363, 0, !dbg !113 + br i1 %.not.i1463, label %1366, label %1364, !dbg !113 + +1364: ; preds = %__nv_exp2f.exit1462 + %1365 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1263) #3, !dbg !113 + br label %__nv_exp2f.exit1465, !dbg !113 + +1366: ; preds = %__nv_exp2f.exit1462 + %1367 = tail call float @llvm.nvvm.ex2.approx.f(float %1263) #3, !dbg !113 + br label %__nv_exp2f.exit1465, !dbg !113 + +__nv_exp2f.exit1465: ; preds = %1364, %1366 + %.0.i1464 = phi float [ %1365, %1364 ], [ %1367, %1366 ], !dbg !113 + %1368 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1466 = icmp eq i32 %1368, 0, !dbg !113 + br i1 %.not.i1466, label %1371, label %1369, !dbg !113 + +1369: ; preds = %__nv_exp2f.exit1465 + %1370 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1264) #3, !dbg !113 + br label %__nv_exp2f.exit1468, !dbg !113 + +1371: ; preds = %__nv_exp2f.exit1465 + %1372 = tail call float @llvm.nvvm.ex2.approx.f(float %1264) #3, !dbg !113 + br label %__nv_exp2f.exit1468, !dbg !113 + +__nv_exp2f.exit1468: ; preds = %1369, %1371 + %.0.i1467 = phi float [ %1370, %1369 ], [ %1372, %1371 ], !dbg !113 + %1373 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1469 = icmp eq i32 %1373, 0, !dbg !113 + br i1 %.not.i1469, label %1376, label %1374, !dbg !113 + +1374: ; preds = %__nv_exp2f.exit1468 + %1375 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1265) #3, !dbg !113 + br label %__nv_exp2f.exit1471, !dbg !113 + +1376: ; preds = %__nv_exp2f.exit1468 + %1377 = tail call float @llvm.nvvm.ex2.approx.f(float %1265) #3, !dbg !113 + br label %__nv_exp2f.exit1471, !dbg !113 + +__nv_exp2f.exit1471: ; preds = %1374, %1376 + %.0.i1470 = phi float [ %1375, %1374 ], [ %1377, %1376 ], !dbg !113 + %1378 = tail call i32 
@__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1472 = icmp eq i32 %1378, 0, !dbg !113 + br i1 %.not.i1472, label %1381, label %1379, !dbg !113 + +1379: ; preds = %__nv_exp2f.exit1471 + %1380 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1266) #3, !dbg !113 + br label %__nv_exp2f.exit1474, !dbg !113 + +1381: ; preds = %__nv_exp2f.exit1471 + %1382 = tail call float @llvm.nvvm.ex2.approx.f(float %1266) #3, !dbg !113 + br label %__nv_exp2f.exit1474, !dbg !113 + +__nv_exp2f.exit1474: ; preds = %1379, %1381 + %.0.i1473 = phi float [ %1380, %1379 ], [ %1382, %1381 ], !dbg !113 + %1383 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1475 = icmp eq i32 %1383, 0, !dbg !113 + br i1 %.not.i1475, label %1386, label %1384, !dbg !113 + +1384: ; preds = %__nv_exp2f.exit1474 + %1385 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1267) #3, !dbg !113 + br label %__nv_exp2f.exit1477, !dbg !113 + +1386: ; preds = %__nv_exp2f.exit1474 + %1387 = tail call float @llvm.nvvm.ex2.approx.f(float %1267) #3, !dbg !113 + br label %__nv_exp2f.exit1477, !dbg !113 + +__nv_exp2f.exit1477: ; preds = %1384, %1386 + %.0.i1476 = phi float [ %1385, %1384 ], [ %1387, %1386 ], !dbg !113 + %1388 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1478 = icmp eq i32 %1388, 0, !dbg !113 + br i1 %.not.i1478, label %1391, label %1389, !dbg !113 + +1389: ; preds = %__nv_exp2f.exit1477 + %1390 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1268) #3, !dbg !113 + br label %__nv_exp2f.exit1480, !dbg !113 + +1391: ; preds = %__nv_exp2f.exit1477 + %1392 = tail call float @llvm.nvvm.ex2.approx.f(float %1268) #3, !dbg !113 + br label %__nv_exp2f.exit1480, !dbg !113 + +__nv_exp2f.exit1480: ; preds = %1389, %1391 + %.0.i1479 = phi float [ %1390, %1389 ], [ %1392, %1391 ], !dbg !113 + %1393 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1481 = icmp eq i32 %1393, 0, !dbg !113 + br i1 %.not.i1481, label %1396, label 
%1394, !dbg !113 + +1394: ; preds = %__nv_exp2f.exit1480 + %1395 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1269) #3, !dbg !113 + br label %__nv_exp2f.exit1483, !dbg !113 + +1396: ; preds = %__nv_exp2f.exit1480 + %1397 = tail call float @llvm.nvvm.ex2.approx.f(float %1269) #3, !dbg !113 + br label %__nv_exp2f.exit1483, !dbg !113 + +__nv_exp2f.exit1483: ; preds = %1394, %1396 + %.0.i1482 = phi float [ %1395, %1394 ], [ %1397, %1396 ], !dbg !113 + %1398 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1484 = icmp eq i32 %1398, 0, !dbg !113 + br i1 %.not.i1484, label %1401, label %1399, !dbg !113 + +1399: ; preds = %__nv_exp2f.exit1483 + %1400 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1270) #3, !dbg !113 + br label %__nv_exp2f.exit1486, !dbg !113 + +1401: ; preds = %__nv_exp2f.exit1483 + %1402 = tail call float @llvm.nvvm.ex2.approx.f(float %1270) #3, !dbg !113 + br label %__nv_exp2f.exit1486, !dbg !113 + +__nv_exp2f.exit1486: ; preds = %1399, %1401 + %.0.i1485 = phi float [ %1400, %1399 ], [ %1402, %1401 ], !dbg !113 + %1403 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1487 = icmp eq i32 %1403, 0, !dbg !113 + br i1 %.not.i1487, label %1406, label %1404, !dbg !113 + +1404: ; preds = %__nv_exp2f.exit1486 + %1405 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1271) #3, !dbg !113 + br label %__nv_exp2f.exit1489, !dbg !113 + +1406: ; preds = %__nv_exp2f.exit1486 + %1407 = tail call float @llvm.nvvm.ex2.approx.f(float %1271) #3, !dbg !113 + br label %__nv_exp2f.exit1489, !dbg !113 + +__nv_exp2f.exit1489: ; preds = %1404, %1406 + %.0.i1488 = phi float [ %1405, %1404 ], [ %1407, %1406 ], !dbg !113 + %1408 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1490 = icmp eq i32 %1408, 0, !dbg !113 + br i1 %.not.i1490, label %1411, label %1409, !dbg !113 + +1409: ; preds = %__nv_exp2f.exit1489 + %1410 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1272) #3, !dbg !113 + br 
label %__nv_exp2f.exit1492, !dbg !113 + +1411: ; preds = %__nv_exp2f.exit1489 + %1412 = tail call float @llvm.nvvm.ex2.approx.f(float %1272) #3, !dbg !113 + br label %__nv_exp2f.exit1492, !dbg !113 + +__nv_exp2f.exit1492: ; preds = %1409, %1411 + %.0.i1491 = phi float [ %1410, %1409 ], [ %1412, %1411 ], !dbg !113 + %1413 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1493 = icmp eq i32 %1413, 0, !dbg !113 + br i1 %.not.i1493, label %1416, label %1414, !dbg !113 + +1414: ; preds = %__nv_exp2f.exit1492 + %1415 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1273) #3, !dbg !113 + br label %__nv_exp2f.exit1495, !dbg !113 + +1416: ; preds = %__nv_exp2f.exit1492 + %1417 = tail call float @llvm.nvvm.ex2.approx.f(float %1273) #3, !dbg !113 + br label %__nv_exp2f.exit1495, !dbg !113 + +__nv_exp2f.exit1495: ; preds = %1414, %1416 + %.0.i1494 = phi float [ %1415, %1414 ], [ %1417, %1416 ], !dbg !113 + %1418 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1496 = icmp eq i32 %1418, 0, !dbg !113 + br i1 %.not.i1496, label %1421, label %1419, !dbg !113 + +1419: ; preds = %__nv_exp2f.exit1495 + %1420 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1274) #3, !dbg !113 + br label %__nv_exp2f.exit1498, !dbg !113 + +1421: ; preds = %__nv_exp2f.exit1495 + %1422 = tail call float @llvm.nvvm.ex2.approx.f(float %1274) #3, !dbg !113 + br label %__nv_exp2f.exit1498, !dbg !113 + +__nv_exp2f.exit1498: ; preds = %1419, %1421 + %.0.i1497 = phi float [ %1420, %1419 ], [ %1422, %1421 ], !dbg !113 + %1423 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1499 = icmp eq i32 %1423, 0, !dbg !113 + br i1 %.not.i1499, label %1426, label %1424, !dbg !113 + +1424: ; preds = %__nv_exp2f.exit1498 + %1425 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1275) #3, !dbg !113 + br label %__nv_exp2f.exit1501, !dbg !113 + +1426: ; preds = %__nv_exp2f.exit1498 + %1427 = tail call float @llvm.nvvm.ex2.approx.f(float %1275) 
#3, !dbg !113 + br label %__nv_exp2f.exit1501, !dbg !113 + +__nv_exp2f.exit1501: ; preds = %1424, %1426 + %.0.i1500 = phi float [ %1425, %1424 ], [ %1427, %1426 ], !dbg !113 + %1428 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1502 = icmp eq i32 %1428, 0, !dbg !113 + br i1 %.not.i1502, label %1431, label %1429, !dbg !113 + +1429: ; preds = %__nv_exp2f.exit1501 + %1430 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1276) #3, !dbg !113 + br label %__nv_exp2f.exit1504, !dbg !113 + +1431: ; preds = %__nv_exp2f.exit1501 + %1432 = tail call float @llvm.nvvm.ex2.approx.f(float %1276) #3, !dbg !113 + br label %__nv_exp2f.exit1504, !dbg !113 + +__nv_exp2f.exit1504: ; preds = %1429, %1431 + %.0.i1503 = phi float [ %1430, %1429 ], [ %1432, %1431 ], !dbg !113 + %1433 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !113 + %.not.i1505 = icmp eq i32 %1433, 0, !dbg !113 + br i1 %.not.i1505, label %1436, label %1434, !dbg !113 + +1434: ; preds = %__nv_exp2f.exit1504 + %1435 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1277) #3, !dbg !113 + br label %__nv_exp2f.exit1507, !dbg !113 + +1436: ; preds = %__nv_exp2f.exit1504 + %1437 = tail call float @llvm.nvvm.ex2.approx.f(float %1277) #3, !dbg !113 + br label %__nv_exp2f.exit1507, !dbg !113 + +__nv_exp2f.exit1507: ; preds = %1434, %1436 + %.0.i1506 = phi float [ %1435, %1434 ], [ %1437, %1436 ], !dbg !113 + %1438 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %527, !dbg !87 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !114 + %1439 = add i32 %531, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1440 = lshr exact i32 %1439, 4, !dbg !114 + %1441 = and i32 %1440, 16383, !dbg !114 + %1442 = zext nneg i32 %1441 to i64, !dbg !114 + %1443 = or disjoint i64 %1442, 4611686293372403712, !dbg !114 + %1444 = ptrtoint ptr addrspace(3) %1438 to i32, !dbg !114 
+ %1445 = lshr exact i32 %1444, 4, !dbg !114 + %1446 = and i32 %1445, 16383, !dbg !114 + %1447 = zext nneg i32 %1446 to i64, !dbg !114 + %1448 = or disjoint i64 %1447, 4611686293338849280, !dbg !114 + %1449 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %1443, i64 %1448) #3, !dbg !114 + %1450 = add i32 %543, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1451 = lshr exact i32 %1450, 4, !dbg !114 + %1452 = and i32 %1451, 16383, !dbg !114 + %1453 = zext nneg i32 %1452 to i64, !dbg !114 + %1454 = or disjoint i64 %1453, 4611686293372403712, !dbg !114 + %1455 = add i32 %1444, 32, !dbg !114 + %1456 = lshr exact i32 %1455, 4, !dbg !114 + %1457 = and i32 %1456, 16383, !dbg !114 + %1458 = zext nneg i32 %1457 to i64, !dbg !114 + %1459 = or disjoint i64 %1458, 4611686293338849280, !dbg !114 + %1460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 0, !dbg !114 + %1461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 1, !dbg !114 + %1462 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 2, !dbg !114 + %1463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 3, !dbg !114 + %1464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 4, !dbg !114 + %1465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 5, !dbg !114 + %1466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 6, !dbg !114 + %1467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 7, !dbg !114 + %1468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 8, !dbg !114 + %1469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %1449, 9, !dbg !114 + %1470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 10, !dbg !114 + %1471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 11, !dbg !114 + %1472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 12, !dbg !114 + %1473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 13, !dbg !114 + %1474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 14, !dbg !114 + %1475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 15, !dbg !114 + %1476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 16, !dbg !114 + %1477 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 17, !dbg !114 + %1478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 18, !dbg !114 + %1479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 19, !dbg !114 + %1480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 20, !dbg !114 + %1481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 21, !dbg !114 + %1482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 22, !dbg !114 + %1483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 23, !dbg !114 + %1484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 24, !dbg !114 + %1485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 25, !dbg !114 + %1486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 26, !dbg !114 + %1487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 27, !dbg !114 + %1488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 28, !dbg !114 + %1489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 29, !dbg !114 + %1490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1449, 30, !dbg !114 + %1491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %1449, 31, !dbg !114 + %1492 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1460, float %1461, float %1462, float %1463, float %1464, float %1465, float %1466, float %1467, float %1468, float %1469, float %1470, float %1471, float %1472, float %1473, float %1474, float %1475, float %1476, float %1477, float %1478, float %1479, float %1480, float %1481, float %1482, float %1483, float %1484, float %1485, float %1486, float %1487, float %1488, float %1489, float %1490, float %1491, i64 %1454, i64 %1459, i1 true) #3, !dbg !114 + %1493 = add i32 %587, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1494 = lshr exact i32 %1493, 4, !dbg !114 + %1495 = and i32 %1494, 16383, !dbg !114 + %1496 = zext nneg i32 %1495 to i64, !dbg !114 + %1497 = or disjoint i64 %1496, 4611686293372403712, !dbg !114 + %1498 = add i32 %1444, 64, !dbg !114 + %1499 = lshr exact i32 %1498, 4, !dbg !114 + %1500 = and i32 %1499, 16383, !dbg !114 + %1501 = zext nneg i32 %1500 to i64, !dbg !114 + %1502 = or disjoint i64 %1501, 4611686293338849280, !dbg !114 + %1503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 0, 
!dbg !114 + %1504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 1, !dbg !114 + %1505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 2, !dbg !114 + %1506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 3, !dbg !114 + %1507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 4, !dbg !114 + %1508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 5, !dbg !114 + %1509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 6, !dbg !114 + %1510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 7, !dbg !114 + %1511 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 8, !dbg !114 + %1512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 9, !dbg !114 + %1513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 10, !dbg !114 + %1514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 11, !dbg !114 + %1515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 12, !dbg !114 + %1516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 13, !dbg !114 + %1517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 14, !dbg !114 + %1518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %1492, 15, !dbg !114 + %1519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 16, !dbg !114 + %1520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 17, !dbg !114 + %1521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 18, !dbg !114 + %1522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 19, !dbg !114 + %1523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 20, !dbg !114 + %1524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 21, !dbg !114 + %1525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 22, !dbg !114 + %1526 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 23, !dbg !114 + %1527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 24, !dbg !114 + %1528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 25, !dbg !114 + %1529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 26, !dbg !114 + %1530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 27, !dbg !114 + %1531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 28, !dbg !114 + %1532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 29, !dbg !114 + %1533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 30, !dbg !114 + %1534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1492, 31, !dbg !114 + %1535 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1503, float %1504, float %1505, float %1506, float %1507, float %1508, float %1509, float %1510, float %1511, float %1512, float %1513, float %1514, float %1515, float %1516, float %1517, float %1518, float %1519, float %1520, float %1521, float %1522, float %1523, float %1524, float %1525, float %1526, float %1527, float %1528, float %1529, float %1530, float %1531, float %1532, float %1533, float %1534, i64 %1497, i64 %1502, i1 true) #3, !dbg !114 + %1536 = add i32 %631, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1537 = lshr exact i32 %1536, 4, !dbg !114 + %1538 = and i32 %1537, 16383, !dbg !114 + %1539 = zext nneg i32 %1538 to i64, !dbg !114 + %1540 = or disjoint i64 %1539, 4611686293372403712, !dbg !114 + %1541 = add i32 %1444, 96, !dbg !114 + %1542 = lshr exact i32 %1541, 4, !dbg !114 + %1543 = and i32 %1542, 16383, !dbg !114 + %1544 = 
zext nneg i32 %1543 to i64, !dbg !114 + %1545 = or disjoint i64 %1544, 4611686293338849280, !dbg !114 + %1546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 0, !dbg !114 + %1547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 1, !dbg !114 + %1548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 2, !dbg !114 + %1549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 3, !dbg !114 + %1550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 4, !dbg !114 + %1551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 5, !dbg !114 + %1552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 6, !dbg !114 + %1553 
= extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 7, !dbg !114 + %1554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 8, !dbg !114 + %1555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 9, !dbg !114 + %1556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 10, !dbg !114 + %1557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 11, !dbg !114 + %1558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 12, !dbg !114 + %1559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 13, !dbg !114 + %1560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 14, !dbg !114 + %1561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 15, !dbg !114 + %1562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 16, !dbg !114 + %1563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 17, !dbg !114 + %1564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 18, !dbg !114 + %1565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 19, !dbg !114 + %1566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 20, !dbg !114 + %1567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %1535, 21, !dbg !114 + %1568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 22, !dbg !114 + %1569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 23, !dbg !114 + %1570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 24, !dbg !114 + %1571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 25, !dbg !114 + %1572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 26, !dbg !114 + %1573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 27, !dbg !114 + %1574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 28, !dbg !114 + %1575 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 29, !dbg !114 + %1576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 30, !dbg !114 + %1577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1535, 31, !dbg !114 + %1578 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1546, float %1547, float %1548, float %1549, float %1550, float %1551, float %1552, float %1553, float %1554, float %1555, float %1556, float %1557, float %1558, float %1559, float %1560, float %1561, float %1562, float %1563, float %1564, float %1565, float %1566, float %1567, float %1568, float %1569, float %1570, float %1571, float %1572, float %1573, float %1574, float %1575, float %1576, float %1577, i64 %1540, i64 %1545, i1 true) #3, !dbg !114 + %1579 = add i32 %675, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 
131072) to i32), !dbg !114 + %1580 = lshr exact i32 %1579, 4, !dbg !114 + %1581 = and i32 %1580, 16383, !dbg !114 + %1582 = zext nneg i32 %1581 to i64, !dbg !114 + %1583 = or disjoint i64 %1582, 4611686293372403712, !dbg !114 + %1584 = add i32 %1444, 8192, !dbg !114 + %1585 = lshr exact i32 %1584, 4, !dbg !114 + %1586 = and i32 %1585, 16383, !dbg !114 + %1587 = zext nneg i32 %1586 to i64, !dbg !114 + %1588 = or disjoint i64 %1587, 4611686293338849280, !dbg !114 + %1589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 0, !dbg !114 + %1590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 1, !dbg !114 + %1591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 2, !dbg !114 + %1592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 3, !dbg !114 + %1593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 4, !dbg !114 + %1594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %1578, 5, !dbg !114 + %1595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 6, !dbg !114 + %1596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 7, !dbg !114 + %1597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 8, !dbg !114 + %1598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 9, !dbg !114 + %1599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 10, !dbg !114 + %1600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 11, !dbg !114 + %1601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 12, !dbg !114 + %1602 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 13, !dbg !114 + %1603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 14, !dbg !114 + %1604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 15, !dbg !114 + %1605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 16, !dbg !114 + %1606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 17, !dbg !114 + %1607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 18, !dbg !114 + %1608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 19, !dbg !114 + %1609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 20, !dbg !114 + %1610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 21, !dbg !114 + %1611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 22, !dbg !114 + %1612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 23, !dbg !114 + %1613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 24, !dbg !114 + %1614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 25, !dbg !114 + %1615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 26, !dbg !114 + %1616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %1578, 27, !dbg !114 + %1617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 28, !dbg !114 + %1618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 29, !dbg !114 + %1619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 30, !dbg !114 + %1620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1578, 31, !dbg !114 + %1621 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1589, float %1590, float %1591, float %1592, float %1593, float %1594, float %1595, float %1596, float %1597, float %1598, float %1599, float %1600, float %1601, float %1602, float %1603, 
float %1604, float %1605, float %1606, float %1607, float %1608, float %1609, float %1610, float %1611, float %1612, float %1613, float %1614, float %1615, float %1616, float %1617, float %1618, float %1619, float %1620, i64 %1583, i64 %1588, i1 true) #3, !dbg !114 + %1622 = add i32 %719, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1623 = lshr exact i32 %1622, 4, !dbg !114 + %1624 = and i32 %1623, 16383, !dbg !114 + %1625 = zext nneg i32 %1624 to i64, !dbg !114 + %1626 = or disjoint i64 %1625, 4611686293372403712, !dbg !114 + %1627 = add i32 %1444, 8224, !dbg !114 + %1628 = lshr exact i32 %1627, 4, !dbg !114 + %1629 = and i32 %1628, 16383, !dbg !114 + %1630 = zext nneg i32 %1629 to i64, !dbg !114 + %1631 = or disjoint i64 %1630, 4611686293338849280, !dbg !114 + %1632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 0, !dbg !114 + %1633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 1, !dbg !114 + %1634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 2, !dbg !114 + %1635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 3, !dbg !114 + %1636 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 4, !dbg !114 + %1637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 5, !dbg !114 + %1638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 6, !dbg !114 + %1639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 7, !dbg !114 + %1640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 8, !dbg !114 + %1641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 9, !dbg !114 + %1642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 10, !dbg !114 + %1643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %1621, 11, !dbg !114 + %1644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 12, !dbg !114 + %1645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 13, !dbg !114 + %1646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 14, !dbg !114 + %1647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 15, !dbg !114 + %1648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 16, !dbg !114 + %1649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 17, !dbg !114 + %1650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 18, !dbg !114 + %1651 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 19, !dbg !114 + %1652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 20, !dbg !114 + %1653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 21, !dbg !114 + %1654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 22, !dbg !114 + %1655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 23, !dbg !114 + %1656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 24, !dbg !114 + %1657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 25, !dbg !114 + %1658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 26, !dbg !114 + %1659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 27, !dbg !114 + %1660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 28, !dbg !114 + %1661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 29, !dbg !114 + %1662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 30, !dbg !114 + %1663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1621, 31, !dbg !114 + %1664 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1632, float %1633, float %1634, float %1635, float %1636, float %1637, float %1638, float %1639, float %1640, float %1641, float %1642, float %1643, float %1644, float %1645, float %1646, float %1647, float %1648, float %1649, float %1650, float %1651, float %1652, float %1653, float %1654, float %1655, float %1656, float %1657, float %1658, float %1659, float %1660, float %1661, float %1662, float %1663, i64 %1626, i64 %1631, i1 true) #3, !dbg !114 + %1665 = add i32 %763, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1666 = lshr exact i32 %1665, 4, !dbg !114 + %1667 = and i32 %1666, 16383, !dbg !114 + %1668 = zext nneg i32 %1667 to i64, !dbg !114 + %1669 = or disjoint i64 %1668, 4611686293372403712, !dbg !114 + %1670 = add i32 %1444, 8256, !dbg !114 + %1671 = lshr exact i32 %1670, 4, !dbg !114 + %1672 = and i32 %1671, 16383, !dbg !114 + %1673 = zext nneg i32 %1672 to i64, !dbg !114 + %1674 = or disjoint i64 %1673, 4611686293338849280, !dbg !114 + %1675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 0, !dbg !114 + %1676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 1, !dbg !114 + %1677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %1664, 2, !dbg !114 + %1678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 3, !dbg !114 + %1679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 4, !dbg !114 + %1680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 5, !dbg !114 + %1681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 6, !dbg !114 + %1682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 7, !dbg !114 + %1683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 8, !dbg !114 + %1684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 9, !dbg !114 + %1685 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 10, !dbg !114 + %1686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 11, !dbg !114 + %1687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 12, !dbg !114 + %1688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 13, !dbg !114 + %1689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 14, !dbg !114 + %1690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 15, !dbg !114 + %1691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 16, !dbg !114 + %1692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %1664, 17, !dbg !114 + %1693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 18, !dbg !114 + %1694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 19, !dbg !114 + %1695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 20, !dbg !114 + %1696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 21, !dbg !114 + %1697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 22, !dbg !114 + %1698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 23, !dbg !114 + %1699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 24, !dbg !114 + 
%1700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 25, !dbg !114 + %1701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 26, !dbg !114 + %1702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 27, !dbg !114 + %1703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 28, !dbg !114 + %1704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 29, !dbg !114 + %1705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 30, !dbg !114 + %1706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1664, 31, !dbg !114 + %1707 = tail call { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1675, float %1676, float %1677, float %1678, float %1679, float %1680, float %1681, float %1682, float %1683, float %1684, float %1685, float %1686, float %1687, float %1688, float %1689, float %1690, float %1691, float %1692, float %1693, float %1694, float %1695, float %1696, float %1697, float %1698, float %1699, float %1700, float %1701, float %1702, float %1703, float %1704, float %1705, float %1706, i64 %1669, i64 %1674, i1 true) #3, !dbg !114 + %1708 = add i32 %807, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !114 + %1709 = lshr exact i32 %1708, 4, !dbg !114 + %1710 = and i32 %1709, 16383, !dbg !114 + %1711 = zext nneg i32 %1710 to i64, !dbg !114 + %1712 = or disjoint i64 %1711, 4611686293372403712, !dbg !114 + %1713 = add i32 %1444, 8288, !dbg !114 + %1714 = lshr exact i32 %1713, 4, !dbg !114 + %1715 = and i32 %1714, 16383, !dbg !114 + %1716 = zext nneg i32 %1715 to i64, !dbg !114 + %1717 = or disjoint i64 %1716, 4611686293338849280, !dbg !114 + %1718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 0, !dbg !114 + %1719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 1, !dbg !114 + %1720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 2, !dbg !114 + %1721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 3, !dbg !114 + %1722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 4, !dbg !114 + %1723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 5, !dbg !114 + %1724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 6, !dbg !114 + %1725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 7, !dbg !114 + %1726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %1707, 8, !dbg !114 + %1727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 9, !dbg !114 + %1728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 10, !dbg !114 + %1729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 11, !dbg !114 + %1730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 12, !dbg !114 + %1731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 13, !dbg !114 + %1732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 14, !dbg !114 + %1733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 15, !dbg !114 + %1734 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 16, !dbg !114 + %1735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 17, !dbg !114 + %1736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 18, !dbg !114 + %1737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 19, !dbg !114 + %1738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 20, !dbg !114 + %1739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 21, !dbg !114 + %1740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 22, !dbg !114 + %1741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %1707, 23, !dbg !114 + %1742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 24, !dbg !114 + %1743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 25, !dbg !114 + %1744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 26, !dbg !114 + %1745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 27, !dbg !114 + %1746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 28, !dbg !114 + %1747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 29, !dbg !114 + %1748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 30, !dbg !114 + 
%1749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1707, 31, !dbg !114 + %1750 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1718, float %1719, float %1720, float %1721, float %1722, float %1723, float %1724, float %1725, float %1726, float %1727, float %1728, float %1729, float %1730, float %1731, float %1732, float %1733, float %1734, float %1735, float %1736, float %1737, float %1738, float %1739, float %1740, float %1741, float %1742, float %1743, float %1744, float %1745, float %1746, float %1747, float %1748, float %1749, i64 %1712, i64 %1717, i1 true) #3, !dbg !114 + %1751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 0, !dbg !114 + %1752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 1, !dbg !114 + %1753 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 2, !dbg !114 + %1754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 3, !dbg !114 + %1755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 4, !dbg !114 + %1756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 5, !dbg !114 + %1757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 6, !dbg !114 + %1758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 7, !dbg !114 + %1759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 8, !dbg !114 + %1760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %1750, 9, !dbg !114 + %1761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 10, !dbg !114 + %1762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 11, !dbg !114 + %1763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 12, !dbg !114 + %1764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 13, !dbg !114 + %1765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 14, !dbg !114 + %1766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 15, !dbg !114 + %1767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 16, !dbg !114 + %1768 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 17, !dbg !114 + %1769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 18, !dbg !114 + %1770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 19, !dbg !114 + %1771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 20, !dbg !114 + %1772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 21, !dbg !114 + %1773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 22, !dbg !114 + %1774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 23, !dbg !114 + %1775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 24, !dbg !114 + %1776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 25, !dbg !114 + %1777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 26, !dbg !114 + %1778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 27, !dbg !114 + %1779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 28, !dbg !114 + %1780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 29, !dbg !114 + %1781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1750, 30, !dbg !114 + %1782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %1750, 31, !dbg !114 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !114 + %1783 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %1751, float %1752, float %1753, float %1754, float %1755, float %1756, float %1757, float %1758, float %1759, float %1760, float %1761, float %1762, float %1763, float %1764, float %1765, float %1766, float %1767, float %1768, float %1769, float %1770, float %1771, float %1772, float %1773, float %1774, float %1775, float %1776, float %1777, float %1778, float %1779, float %1780, float %1781, float %1782, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %1438, i32 0, i32 0) #3, !dbg !114 + %1784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 0, !dbg !114 + %1785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 1, !dbg !114 + %1786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 2, !dbg !114 + %1787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 3, !dbg !114 + %1788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 4, !dbg !114 + %1789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 5, !dbg !114 + %1790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 6, !dbg !114 + %1791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, 
ptr addrspace(3), i32, i32 } %1783, 7, !dbg !114 + %1792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 8, !dbg !114 + %1793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 9, !dbg !114 + %1794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 10, !dbg !114 + %1795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 11, !dbg !114 + %1796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 12, !dbg !114 + %1797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 13, 
!dbg !114 + %1798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 14, !dbg !114 + %1799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 15, !dbg !114 + %1800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 16, !dbg !114 + %1801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 17, !dbg !114 + %1802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 18, !dbg !114 + %1803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 19, !dbg !114 + %1804 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 20, !dbg !114 + %1805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 21, !dbg !114 + %1806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 22, !dbg !114 + %1807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 23, !dbg !114 + %1808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 24, !dbg !114 + %1809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 25, !dbg !114 + %1810 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 26, !dbg !114 + %1811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 27, !dbg !114 + %1812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 28, !dbg !114 + %1813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 29, !dbg !114 + %1814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 30, !dbg !114 + %1815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1783, 31, !dbg !114 + %1816 = fsub float %1784, %345, !dbg !115 + %1817 = fsub float %1785, %345, !dbg !115 + %1818 = 
fsub float %1786, %347, !dbg !115 + %1819 = fsub float %1787, %347, !dbg !115 + %1820 = fsub float %1788, %345, !dbg !115 + %1821 = fsub float %1789, %345, !dbg !115 + %1822 = fsub float %1790, %347, !dbg !115 + %1823 = fsub float %1791, %347, !dbg !115 + %1824 = fsub float %1792, %345, !dbg !115 + %1825 = fsub float %1793, %345, !dbg !115 + %1826 = fsub float %1794, %347, !dbg !115 + %1827 = fsub float %1795, %347, !dbg !115 + %1828 = fsub float %1796, %345, !dbg !115 + %1829 = fsub float %1797, %345, !dbg !115 + %1830 = fsub float %1798, %347, !dbg !115 + %1831 = fsub float %1799, %347, !dbg !115 + %1832 = fsub float %1800, %345, !dbg !115 + %1833 = fsub float %1801, %345, !dbg !115 + %1834 = fsub float %1802, %347, !dbg !115 + %1835 = fsub float %1803, %347, !dbg !115 + %1836 = fsub float %1804, %345, !dbg !115 + %1837 = fsub float %1805, %345, !dbg !115 + %1838 = fsub float %1806, %347, !dbg !115 + %1839 = fsub float %1807, %347, !dbg !115 + %1840 = fsub float %1808, %345, !dbg !115 + %1841 = fsub float %1809, %345, !dbg !115 + %1842 = fsub float %1810, %347, !dbg !115 + %1843 = fsub float %1811, %347, !dbg !115 + %1844 = fsub float %1812, %345, !dbg !115 + %1845 = fsub float %1813, %345, !dbg !115 + %1846 = fsub float %1814, %347, !dbg !115 + %1847 = fsub float %1815, %347, !dbg !115 + %1848 = fmul float %.0.i1413, %1816, !dbg !116 + %1849 = fmul float %.0.i1416, %1817, !dbg !116 + %1850 = fmul float %.0.i1419, %1818, !dbg !116 + %1851 = fmul float %.0.i1422, %1819, !dbg !116 + %1852 = fmul float %.0.i1425, %1820, !dbg !116 + %1853 = fmul float %.0.i1428, %1821, !dbg !116 + %1854 = fmul float %.0.i1431, %1822, !dbg !116 + %1855 = fmul float %.0.i1434, %1823, !dbg !116 + %1856 = fmul float %.0.i1437, %1824, !dbg !116 + %1857 = fmul float %.0.i1440, %1825, !dbg !116 + %1858 = fmul float %.0.i1443, %1826, !dbg !116 + %1859 = fmul float %.0.i1446, %1827, !dbg !116 + %1860 = fmul float %.0.i1449, %1828, !dbg !116 + %1861 = fmul float %.0.i1452, %1829, !dbg !116 + 
%1862 = fmul float %.0.i1455, %1830, !dbg !116 + %1863 = fmul float %.0.i1458, %1831, !dbg !116 + %1864 = fmul float %.0.i1461, %1832, !dbg !116 + %1865 = fmul float %.0.i1464, %1833, !dbg !116 + %1866 = fmul float %.0.i1467, %1834, !dbg !116 + %1867 = fmul float %.0.i1470, %1835, !dbg !116 + %1868 = fmul float %.0.i1473, %1836, !dbg !116 + %1869 = fmul float %.0.i1476, %1837, !dbg !116 + %1870 = fmul float %.0.i1479, %1838, !dbg !116 + %1871 = fmul float %.0.i1482, %1839, !dbg !116 + %1872 = fmul float %.0.i1485, %1840, !dbg !116 + %1873 = fmul float %.0.i1488, %1841, !dbg !116 + %1874 = fmul float %.0.i1491, %1842, !dbg !116 + %1875 = fmul float %.0.i1494, %1843, !dbg !116 + %1876 = fmul float %.0.i1497, %1844, !dbg !116 + %1877 = fmul float %.0.i1500, %1845, !dbg !116 + %1878 = fmul float %.0.i1503, %1846, !dbg !116 + %1879 = fmul float %.0.i1506, %1847, !dbg !116 + %1880 = fptrunc float %1848 to bfloat, !dbg !117 + %1881 = select i1 %1150, bfloat %1880, bfloat 0xR0000, !dbg !118 + %1882 = fptrunc float %1849 to bfloat, !dbg !117 + %1883 = select i1 %1151, bfloat %1882, bfloat 0xR0000, !dbg !118 + %1884 = fptrunc float %1850 to bfloat, !dbg !117 + %1885 = select i1 %1152, bfloat %1884, bfloat 0xR0000, !dbg !118 + %1886 = fptrunc float %1851 to bfloat, !dbg !117 + %1887 = select i1 %1153, bfloat %1886, bfloat 0xR0000, !dbg !118 + %1888 = fptrunc float %1852 to bfloat, !dbg !117 + %1889 = select i1 %1154, bfloat %1888, bfloat 0xR0000, !dbg !118 + %1890 = fptrunc float %1853 to bfloat, !dbg !117 + %1891 = select i1 %1155, bfloat %1890, bfloat 0xR0000, !dbg !118 + %1892 = fptrunc float %1854 to bfloat, !dbg !117 + %1893 = select i1 %1156, bfloat %1892, bfloat 0xR0000, !dbg !118 + %1894 = fptrunc float %1855 to bfloat, !dbg !117 + %1895 = select i1 %1157, bfloat %1894, bfloat 0xR0000, !dbg !118 + %1896 = fptrunc float %1856 to bfloat, !dbg !117 + %1897 = select i1 %1158, bfloat %1896, bfloat 0xR0000, !dbg !118 + %1898 = fptrunc float %1857 to bfloat, !dbg !117 + 
%1899 = select i1 %1159, bfloat %1898, bfloat 0xR0000, !dbg !118 + %1900 = fptrunc float %1858 to bfloat, !dbg !117 + %1901 = select i1 %1160, bfloat %1900, bfloat 0xR0000, !dbg !118 + %1902 = fptrunc float %1859 to bfloat, !dbg !117 + %1903 = select i1 %1161, bfloat %1902, bfloat 0xR0000, !dbg !118 + %1904 = fptrunc float %1860 to bfloat, !dbg !117 + %1905 = select i1 %1162, bfloat %1904, bfloat 0xR0000, !dbg !118 + %1906 = fptrunc float %1861 to bfloat, !dbg !117 + %1907 = select i1 %1163, bfloat %1906, bfloat 0xR0000, !dbg !118 + %1908 = fptrunc float %1862 to bfloat, !dbg !117 + %1909 = select i1 %1164, bfloat %1908, bfloat 0xR0000, !dbg !118 + %1910 = fptrunc float %1863 to bfloat, !dbg !117 + %1911 = select i1 %1165, bfloat %1910, bfloat 0xR0000, !dbg !118 + %1912 = fptrunc float %1864 to bfloat, !dbg !117 + %1913 = select i1 %1166, bfloat %1912, bfloat 0xR0000, !dbg !118 + %1914 = fptrunc float %1865 to bfloat, !dbg !117 + %1915 = select i1 %1167, bfloat %1914, bfloat 0xR0000, !dbg !118 + %1916 = fptrunc float %1866 to bfloat, !dbg !117 + %1917 = select i1 %1168, bfloat %1916, bfloat 0xR0000, !dbg !118 + %1918 = fptrunc float %1867 to bfloat, !dbg !117 + %1919 = select i1 %1169, bfloat %1918, bfloat 0xR0000, !dbg !118 + %1920 = fptrunc float %1868 to bfloat, !dbg !117 + %1921 = select i1 %1170, bfloat %1920, bfloat 0xR0000, !dbg !118 + %1922 = fptrunc float %1869 to bfloat, !dbg !117 + %1923 = select i1 %1171, bfloat %1922, bfloat 0xR0000, !dbg !118 + %1924 = fptrunc float %1870 to bfloat, !dbg !117 + %1925 = select i1 %1172, bfloat %1924, bfloat 0xR0000, !dbg !118 + %1926 = fptrunc float %1871 to bfloat, !dbg !117 + %1927 = select i1 %1173, bfloat %1926, bfloat 0xR0000, !dbg !118 + %1928 = fptrunc float %1872 to bfloat, !dbg !117 + %1929 = select i1 %1174, bfloat %1928, bfloat 0xR0000, !dbg !118 + %1930 = fptrunc float %1873 to bfloat, !dbg !117 + %1931 = select i1 %1175, bfloat %1930, bfloat 0xR0000, !dbg !118 + %1932 = fptrunc float %1874 to bfloat, !dbg 
!117 + %1933 = select i1 %1176, bfloat %1932, bfloat 0xR0000, !dbg !118 + %1934 = fptrunc float %1875 to bfloat, !dbg !117 + %1935 = select i1 %1177, bfloat %1934, bfloat 0xR0000, !dbg !118 + %1936 = fptrunc float %1876 to bfloat, !dbg !117 + %1937 = select i1 %1178, bfloat %1936, bfloat 0xR0000, !dbg !118 + %1938 = fptrunc float %1877 to bfloat, !dbg !117 + %1939 = select i1 %1179, bfloat %1938, bfloat 0xR0000, !dbg !118 + %1940 = fptrunc float %1878 to bfloat, !dbg !117 + %1941 = select i1 %1180, bfloat %1940, bfloat 0xR0000, !dbg !118 + %1942 = fptrunc float %1879 to bfloat, !dbg !117 + %1943 = select i1 %1181, bfloat %1942, bfloat 0xR0000, !dbg !118 + %1944 = insertelement <2 x bfloat> poison, bfloat %1881, i64 0, !dbg !119 + %1945 = insertelement <2 x bfloat> %1944, bfloat %1883, i64 1, !dbg !119 + %1946 = bitcast <2 x bfloat> %1945 to i32, !dbg !119 + %1947 = insertelement <2 x bfloat> poison, bfloat %1885, i64 0, !dbg !119 + %1948 = insertelement <2 x bfloat> %1947, bfloat %1887, i64 1, !dbg !119 + %1949 = bitcast <2 x bfloat> %1948 to i32, !dbg !119 + %1950 = insertelement <2 x bfloat> poison, bfloat %1889, i64 0, !dbg !119 + %1951 = insertelement <2 x bfloat> %1950, bfloat %1891, i64 1, !dbg !119 + %1952 = bitcast <2 x bfloat> %1951 to i32, !dbg !119 + %1953 = insertelement <2 x bfloat> poison, bfloat %1893, i64 0, !dbg !119 + %1954 = insertelement <2 x bfloat> %1953, bfloat %1895, i64 1, !dbg !119 + %1955 = bitcast <2 x bfloat> %1954 to i32, !dbg !119 + %1956 = insertelement <2 x bfloat> poison, bfloat %1897, i64 0, !dbg !119 + %1957 = insertelement <2 x bfloat> %1956, bfloat %1899, i64 1, !dbg !119 + %1958 = bitcast <2 x bfloat> %1957 to i32, !dbg !119 + %1959 = insertelement <2 x bfloat> poison, bfloat %1901, i64 0, !dbg !119 + %1960 = insertelement <2 x bfloat> %1959, bfloat %1903, i64 1, !dbg !119 + %1961 = bitcast <2 x bfloat> %1960 to i32, !dbg !119 + %1962 = insertelement <2 x bfloat> poison, bfloat %1905, i64 0, !dbg !119 + %1963 = insertelement 
<2 x bfloat> %1962, bfloat %1907, i64 1, !dbg !119 + %1964 = bitcast <2 x bfloat> %1963 to i32, !dbg !119 + %1965 = insertelement <2 x bfloat> poison, bfloat %1909, i64 0, !dbg !119 + %1966 = insertelement <2 x bfloat> %1965, bfloat %1911, i64 1, !dbg !119 + %1967 = bitcast <2 x bfloat> %1966 to i32, !dbg !119 + %1968 = insertelement <2 x bfloat> poison, bfloat %1913, i64 0, !dbg !119 + %1969 = insertelement <2 x bfloat> %1968, bfloat %1915, i64 1, !dbg !119 + %1970 = bitcast <2 x bfloat> %1969 to i32, !dbg !119 + %1971 = insertelement <2 x bfloat> poison, bfloat %1917, i64 0, !dbg !119 + %1972 = insertelement <2 x bfloat> %1971, bfloat %1919, i64 1, !dbg !119 + %1973 = bitcast <2 x bfloat> %1972 to i32, !dbg !119 + %1974 = insertelement <2 x bfloat> poison, bfloat %1921, i64 0, !dbg !119 + %1975 = insertelement <2 x bfloat> %1974, bfloat %1923, i64 1, !dbg !119 + %1976 = bitcast <2 x bfloat> %1975 to i32, !dbg !119 + %1977 = insertelement <2 x bfloat> poison, bfloat %1925, i64 0, !dbg !119 + %1978 = insertelement <2 x bfloat> %1977, bfloat %1927, i64 1, !dbg !119 + %1979 = bitcast <2 x bfloat> %1978 to i32, !dbg !119 + %1980 = insertelement <2 x bfloat> poison, bfloat %1929, i64 0, !dbg !119 + %1981 = insertelement <2 x bfloat> %1980, bfloat %1931, i64 1, !dbg !119 + %1982 = bitcast <2 x bfloat> %1981 to i32, !dbg !119 + %1983 = insertelement <2 x bfloat> poison, bfloat %1933, i64 0, !dbg !119 + %1984 = insertelement <2 x bfloat> %1983, bfloat %1935, i64 1, !dbg !119 + %1985 = bitcast <2 x bfloat> %1984 to i32, !dbg !119 + %1986 = insertelement <2 x bfloat> poison, bfloat %1937, i64 0, !dbg !119 + %1987 = insertelement <2 x bfloat> %1986, bfloat %1939, i64 1, !dbg !119 + %1988 = bitcast <2 x bfloat> %1987 to i32, !dbg !119 + %1989 = insertelement <2 x bfloat> poison, bfloat %1941, i64 0, !dbg !119 + %1990 = insertelement <2 x bfloat> %1989, bfloat %1943, i64 1, !dbg !119 + %1991 = bitcast <2 x bfloat> %1990 to i32, !dbg !119 + tail call void 
@llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !119 + %1992 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %455, float %456, float %457, float %458, float %459, float %460, float %461, float %462, float %463, float %464, float %465, float %466, float %467, float %468, float %469, float %470, float %471, float %472, float %473, float %474, float %475, float %476, float %477, float %478, float %479, float %480, float %481, float %482, float %483, float %484, float %485, float %486, float %487, float %488, float %489, float %490, float %491, float %492, float %493, float %494, float %495, float %496, float %497, float %498, float %499, float %500, float %501, float %502, float %503, float %504, float %505, float %506, float %507, float %508, float %509, float %510, float %511, float %512, float %513, float %514, float 
%515, float %516, float %517, float %518, i32 %1946, i32 %1949, i32 %1952, i32 %1955, i64 %541, i1 true) #3, !dbg !119 + %1993 = add i32 %537, 2048, !dbg !119 + %1994 = lshr exact i32 %1993, 4, !dbg !119 + %1995 = and i32 %1994, 16383, !dbg !119 + %1996 = zext nneg i32 %1995 to i64, !dbg !119 + %1997 = or disjoint i64 %1996, 4611686293338849280, !dbg !119 + %1998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 0, !dbg !119 + %1999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 1, !dbg !119 + %2000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 2, !dbg !119 + %2001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 3, !dbg !119 + %2002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 4, !dbg !119 + %2003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 5, !dbg !119 + %2004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 6, !dbg !119 + %2005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 7, !dbg !119 + %2006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 8, !dbg !119 + %2007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 9, !dbg !119 + %2008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 10, !dbg !119 + %2009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 11, !dbg !119 + %2010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 12, !dbg !119 + %2011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 13, !dbg !119 + %2012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 14, !dbg !119 + %2013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 15, !dbg !119 + %2014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 16, !dbg !119 + %2015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 17, !dbg !119 + %2016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 18, !dbg !119 + %2017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 19, !dbg !119 + %2018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 20, !dbg !119 + %2019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 21, !dbg !119 + %2020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 22, !dbg !119 + %2021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 23, !dbg !119 + %2022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 24, !dbg !119 + %2023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 25, !dbg !119 + %2024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 26, !dbg !119 + %2025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 27, !dbg !119 + %2026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 28, !dbg !119 + %2027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 29, !dbg !119 + %2028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 30, !dbg !119 + %2029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 31, !dbg !119 + %2030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 32, !dbg !119 + %2031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 33, !dbg !119 + %2032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 34, !dbg !119 + %2033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 35, !dbg !119 + %2034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 36, !dbg !119 + %2035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 37, !dbg !119 + %2036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 38, !dbg !119 + %2037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 39, !dbg !119 + %2038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 40, !dbg !119 + %2039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 41, !dbg !119 + %2040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 42, !dbg !119 + %2041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 43, !dbg !119 + %2042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 44, !dbg !119 + %2043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 45, !dbg !119 + %2044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 46, !dbg !119 + %2045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 47, !dbg !119 + %2046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 48, !dbg !119 + %2047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 49, !dbg !119 + %2048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 50, !dbg !119 + %2049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 51, !dbg !119 + %2050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 52, !dbg !119 + %2051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 53, !dbg !119 + %2052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 54, !dbg !119 + %2053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 55, !dbg !119 + %2054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 56, !dbg !119 + %2055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 57, !dbg !119 + %2056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 58, !dbg !119 + %2057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 59, !dbg !119 + %2058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 60, !dbg !119 + %2059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 61, !dbg !119 + %2060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 62, !dbg !119 + %2061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %1992, 63, !dbg !119 + %2062 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %1998, float %1999, float %2000, float %2001, float %2002, float %2003, float %2004, float %2005, float %2006, float %2007, float %2008, float %2009, float %2010, float %2011, float %2012, float %2013, float %2014, float %2015, float %2016, float %2017, float %2018, float %2019, float %2020, float %2021, float %2022, float %2023, float %2024, float %2025, float %2026, float %2027, float %2028, float %2029, float %2030, float %2031, float %2032, float %2033, float %2034, float %2035, float %2036, float %2037, float %2038, float %2039, float %2040, float %2041, float %2042, float %2043, float %2044, float %2045, float %2046, float %2047, float %2048, 
float %2049, float %2050, float %2051, float %2052, float %2053, float %2054, float %2055, float %2056, float %2057, float %2058, float %2059, float %2060, float %2061, i32 %1958, i32 %1961, i32 %1964, i32 %1967, i64 %1997, i1 true) #3, !dbg !119 + %2063 = add i32 %537, 4096, !dbg !119 + %2064 = lshr exact i32 %2063, 4, !dbg !119 + %2065 = and i32 %2064, 16383, !dbg !119 + %2066 = zext nneg i32 %2065 to i64, !dbg !119 + %2067 = or disjoint i64 %2066, 4611686293338849280, !dbg !119 + %2068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 0, !dbg !119 + %2069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 1, !dbg !119 + %2070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 2, !dbg !119 + %2071 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 3, !dbg !119 + %2072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 4, !dbg !119 + %2073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 5, !dbg !119 + %2074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 6, !dbg !119 + %2075 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 7, !dbg !119 + %2076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 8, !dbg !119 + %2077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 9, !dbg !119 + %2078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 10, !dbg !119 + %2079 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 11, !dbg !119 + %2080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 12, !dbg !119 + %2081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 13, !dbg !119 + %2082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 14, !dbg !119 + %2083 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 15, !dbg !119 + %2084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 16, !dbg !119 + %2085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 17, !dbg !119 + %2086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 18, !dbg !119 + %2087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 19, !dbg !119 + %2088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 20, !dbg !119 + %2089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 21, !dbg !119 + %2090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 22, !dbg !119 + %2091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 23, !dbg !119 + %2092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 24, !dbg !119 + %2093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 25, !dbg !119 + %2094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 26, !dbg !119 + %2095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 27, !dbg !119 + %2096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 28, !dbg !119 + %2097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 29, !dbg !119 + %2098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 30, !dbg !119 + %2099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 31, !dbg !119 + %2100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 32, !dbg !119 + %2101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 33, !dbg !119 + %2102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 34, !dbg !119 + %2103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 35, !dbg !119 + %2104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 36, !dbg !119 + %2105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 37, !dbg !119 + %2106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 38, !dbg !119 + %2107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 39, !dbg !119 + %2108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 40, !dbg !119 + %2109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 41, !dbg !119 + %2110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 42, !dbg !119 + %2111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 43, !dbg !119 + %2112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 44, !dbg !119 + %2113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 45, !dbg !119 + %2114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 46, !dbg !119 + %2115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 47, !dbg !119 + %2116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 48, !dbg !119 + %2117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 49, !dbg !119 + %2118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 50, !dbg !119 + %2119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 51, !dbg !119 + %2120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 52, !dbg !119 + %2121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 53, !dbg !119 + %2122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 54, !dbg !119 + %2123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 55, !dbg !119 + %2124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 56, !dbg !119 + %2125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 57, !dbg !119 + %2126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 58, !dbg !119 + %2127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 59, !dbg !119 + %2128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 60, !dbg !119 + %2129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 61, !dbg !119 + %2130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 62, !dbg !119 + %2131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2062, 63, !dbg !119 + %2132 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2068, float %2069, float %2070, float %2071, float %2072, float %2073, float %2074, float %2075, float %2076, float %2077, float %2078, float %2079, float %2080, float %2081, float %2082, float %2083, float %2084, float %2085, float %2086, float %2087, float %2088, float %2089, float %2090, float %2091, float %2092, float %2093, float %2094, float %2095, float %2096, float %2097, float %2098, float %2099, float %2100, float %2101, float %2102, float %2103, float %2104, float %2105, float %2106, float %2107, float %2108, 
float %2109, float %2110, float %2111, float %2112, float %2113, float %2114, float %2115, float %2116, float %2117, float %2118, float %2119, float %2120, float %2121, float %2122, float %2123, float %2124, float %2125, float %2126, float %2127, float %2128, float %2129, float %2130, float %2131, i32 %1970, i32 %1973, i32 %1976, i32 %1979, i64 %2067, i1 true) #3, !dbg !119 + %2133 = add i32 %537, 6144, !dbg !119 + %2134 = lshr exact i32 %2133, 4, !dbg !119 + %2135 = and i32 %2134, 16383, !dbg !119 + %2136 = zext nneg i32 %2135 to i64, !dbg !119 + %2137 = or disjoint i64 %2136, 4611686293338849280, !dbg !119 + %2138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 0, !dbg !119 + %2139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 1, !dbg !119 + %2140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %2132, 2, !dbg !119 + %2141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 3, !dbg !119 + %2142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 4, !dbg !119 + %2143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 5, !dbg !119 + %2144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %2132, 6, !dbg !119 + %2145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 7, !dbg !119 + %2146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 8, !dbg !119 + %2147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 9, !dbg !119 + %2148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %2132, 10, !dbg !119 + %2149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 11, !dbg !119 + %2150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 12, !dbg !119 + %2151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 13, !dbg !119 + %2152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %2132, 14, !dbg !119 + %2153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 15, !dbg !119 + %2154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 16, !dbg !119 + %2155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 17, !dbg !119 + %2156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float 
} %2132, 18, !dbg !119 + %2157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 19, !dbg !119 + %2158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 20, !dbg !119 + %2159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 21, !dbg !119 + %2160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 22, !dbg 
!119 + %2161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 23, !dbg !119 + %2162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 24, !dbg !119 + %2163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 25, !dbg !119 + %2164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 26, !dbg !119 + %2165 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 27, !dbg !119 + %2166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 28, !dbg !119 + %2167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 29, !dbg !119 + %2168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 30, !dbg !119 + %2169 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 31, !dbg !119 + %2170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 32, !dbg !119 + %2171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 33, !dbg !119 + %2172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 34, !dbg !119 + %2173 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 35, !dbg !119 + %2174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 36, !dbg !119 + %2175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 37, !dbg !119 + %2176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 38, !dbg !119 + %2177 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 39, !dbg !119 + %2178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 40, !dbg !119 + %2179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 41, !dbg !119 + %2180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 42, !dbg !119 + %2181 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 43, !dbg !119 + %2182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 44, !dbg !119 + %2183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 45, !dbg !119 + %2184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 46, !dbg !119 + %2185 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 47, !dbg !119 + %2186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 48, !dbg !119 + %2187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 49, !dbg !119 + %2188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 50, !dbg !119 + %2189 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 51, !dbg !119 + %2190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 52, !dbg !119 + %2191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 53, !dbg !119 + %2192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 54, !dbg !119 + %2193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 55, !dbg !119 + %2194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 56, !dbg !119 + %2195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 57, !dbg !119 + %2196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 58, !dbg !119 + %2197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 59, !dbg !119 + %2198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 60, !dbg !119 + %2199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 61, !dbg !119 + %2200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 62, !dbg !119 + %2201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2132, 63, !dbg !119 + %2202 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2138, float %2139, float %2140, float %2141, float %2142, float %2143, float %2144, float %2145, float %2146, float %2147, float %2148, float %2149, float %2150, float %2151, float %2152, float %2153, float %2154, float %2155, float %2156, float %2157, float %2158, float %2159, float %2160, float %2161, float %2162, float %2163, float %2164, float %2165, float %2166, float %2167, float %2168, float %2169, 
float %2170, float %2171, float %2172, float %2173, float %2174, float %2175, float %2176, float %2177, float %2178, float %2179, float %2180, float %2181, float %2182, float %2183, float %2184, float %2185, float %2186, float %2187, float %2188, float %2189, float %2190, float %2191, float %2192, float %2193, float %2194, float %2195, float %2196, float %2197, float %2198, float %2199, float %2200, float %2201, i32 %1982, i32 %1985, i32 %1988, i32 %1991, i64 %2137, i1 true) #3, !dbg !119 + %2203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 0, !dbg !119 + %2204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 1, !dbg !119 + %2205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 2, !dbg !119 + %2206 = extractvalue 
{ float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 3, !dbg !119 + %2207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 4, !dbg !119 + %2208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 5, !dbg !119 + %2209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 6, !dbg !119 + %2210 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 7, !dbg !119 + %2211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 8, !dbg !119 + %2212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 9, !dbg !119 + %2213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 10, !dbg !119 + %2214 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 11, !dbg !119 + %2215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 12, !dbg !119 + %2216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 13, !dbg !119 + %2217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 14, !dbg !119 + %2218 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 15, !dbg !119 + %2219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 16, !dbg !119 + %2220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 17, !dbg !119 + %2221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 18, !dbg !119 + %2222 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 19, !dbg !119 + %2223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 20, !dbg !119 + %2224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 21, !dbg !119 + %2225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 22, !dbg !119 + %2226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 23, !dbg !119 + %2227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 24, !dbg !119 + %2228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 25, !dbg !119 + %2229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 26, !dbg !119 + %2230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 27, !dbg !119 + %2231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 28, !dbg !119 + %2232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 29, !dbg !119 + %2233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 30, !dbg !119 + %2234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 31, !dbg !119 + %2235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 32, !dbg !119 + %2236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 33, !dbg !119 + %2237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 34, !dbg !119 + %2238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 35, !dbg !119 + %2239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 36, !dbg !119 + %2240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 37, !dbg !119 + %2241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 38, !dbg !119 + %2242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 39, !dbg !119 + %2243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 40, !dbg !119 + %2244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 41, !dbg !119 + %2245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 42, !dbg !119 + %2246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 43, !dbg !119 + %2247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 44, !dbg !119 + %2248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 45, !dbg !119 + %2249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 46, !dbg !119 + %2250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 47, !dbg !119 + %2251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 48, !dbg !119 + %2252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 49, !dbg !119 + %2253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 50, !dbg !119 + %2254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 51, !dbg !119 + %2255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 52, !dbg !119 + %2256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 53, !dbg !119 + %2257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 54, !dbg !119 + %2258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 55, !dbg !119 + %2259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 56, !dbg !119 + %2260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 57, !dbg !119 + %2261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 58, !dbg !119 + %2262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 59, !dbg !119 + %2263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 60, !dbg !119 + %2264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 61, !dbg !119 + %2265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 62, !dbg !119 + %2266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2202, 63, !dbg !119 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !119 + %2267 = insertelement <16 x i32> poison, i32 %452, i64 0, !dbg !120 + %2268 = shufflevector <16 x i32> %2267, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !120 + %2269 = add <16 x i32> %2268, %520, !dbg !120 + %2270 = add nuw nsw i32 %519, 1, !dbg !84 + %2271 = lshr i32 %2270, 1, !dbg !121 + %2272 = zext nneg i32 %2271 to i64, !dbg !122 + %2273 = getelementptr i32, ptr addrspace(1) %359, i64 %2272, !dbg !122 + %2274 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !123 + %2275 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2273, i64 %2274, i1 %523) #3, !dbg !123 + %2276 = add nuw nsw i32 %2271, 1, !dbg !124 + %2277 = icmp slt i32 %2276, %364, !dbg !125 + %2278 = getelementptr i8, ptr addrspace(1) %2273, i64 4, !dbg !126 + %2279 = and i1 %523, %2277, !dbg !84 + %2280 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !127 + %2281 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2278, i64 %2280, i1 %2279) #3, !dbg !127 + %2282 = and i32 %519, 1, !dbg !128 + %2283 = sub i32 %2281, %2275, !dbg !129 + %2284 = shl i32 %2283, 7, !dbg !130 + %2285 = add i32 %2284, -64, !dbg !131 + %2286 = xor i32 %2282, 1, !dbg !132 + %2287 = mul nuw nsw i32 %2285, %2286, !dbg !132 + %2288 = shl nuw nsw i32 %2282, 6, !dbg !133 + %2289 = add i32 %2287, %2288, !dbg !134 + %2290 = shl i32 %2289, 7, !dbg !135 + 
%2291 = sext i32 %2290 to i64, !dbg !88 + %2292 = getelementptr bfloat, ptr addrspace(1) %.pn9021821, i64 %2291, !dbg !88 + %2293 = getelementptr bfloat, ptr addrspace(1) %.pn8861822, i64 %2291, !dbg !88 + %2294 = getelementptr bfloat, ptr addrspace(1) %.pn8701823, i64 %2291, !dbg !88 + %2295 = getelementptr bfloat, ptr addrspace(1) %.pn8541824, i64 %2291, !dbg !88 + %2296 = getelementptr bfloat, ptr addrspace(1) %.pn9661825, i64 %2291, !dbg !89 + %2297 = getelementptr bfloat, ptr addrspace(1) %.pn9501826, i64 %2291, !dbg !89 + %2298 = getelementptr bfloat, ptr addrspace(1) %.pn9341827, i64 %2291, !dbg !89 + %2299 = getelementptr bfloat, ptr addrspace(1) %.pn9181828, i64 %2291, !dbg !89 + %2300 = add i32 %454, 1, !dbg !84 + %2301 = icmp sgt i32 %2300, 2, !dbg !84 + %2302 = select i1 %2301, i32 0, i32 %2300, !dbg !84 + %2303 = shl i32 %2302, 13, !dbg !87 + %2304 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2303, !dbg !87 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !87 + %2305 = getelementptr inbounds nuw i8, ptr addrspace(3) %2304, i32 %402, !dbg !87 + %2306 = select i1 %522, i32 16, i32 0, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2305, ptr addrspace(1) %2292, i32 %2306) #3, !dbg !87 + %2307 = getelementptr inbounds nuw i8, ptr addrspace(3) %2304, i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2307, ptr addrspace(1) %2293, i32 %2306) #3, !dbg !87 + %2308 = getelementptr inbounds nuw i8, ptr addrspace(3) %2304, i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2308, ptr addrspace(1) %2294, i32 %2306) #3, !dbg !87 + %2309 = getelementptr inbounds nuw i8, ptr addrspace(3) %2304, i32 %409, !dbg !87 + tail call void asm sideeffect 
"cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2309, ptr addrspace(1) %2295, i32 %2306) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + %2310 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2303, !dbg !87 + %2311 = getelementptr inbounds nuw i8, ptr addrspace(3) %2310, i32 %402, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2311, ptr addrspace(1) %2296, i32 %2306) #3, !dbg !87 + %2312 = getelementptr inbounds nuw i8, ptr addrspace(3) %2310, i32 %405, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2312, ptr addrspace(1) %2297, i32 %2306) #3, !dbg !87 + %2313 = getelementptr inbounds nuw i8, ptr addrspace(3) %2310, i32 %407, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2313, ptr addrspace(1) %2298, i32 %2306) #3, !dbg !87 + %2314 = getelementptr inbounds nuw i8, ptr addrspace(3) %2310, i32 %409, !dbg !87 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2314, ptr addrspace(1) %2299, i32 %2306) #3, !dbg !87 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !87 + %exitcond2125.not = icmp eq i32 %2270, %433, !dbg !84 + br i1 %exitcond2125.not, label %._crit_edge1847, label %451, !dbg !84 + +._crit_edge1847: ; preds = %__nv_exp2f.exit1507, %53 + %2315 = phi float [ 0.000000e+00, %53 ], [ %2203, %__nv_exp2f.exit1507 ] + %2316 = phi float [ 0.000000e+00, %53 ], [ %2204, %__nv_exp2f.exit1507 ] + %2317 = phi float [ 0.000000e+00, %53 ], [ %2205, %__nv_exp2f.exit1507 ] + %2318 = phi float [ 0.000000e+00, %53 ], [ %2206, %__nv_exp2f.exit1507 ] + %2319 = phi float [ 0.000000e+00, %53 ], [ 
%2207, %__nv_exp2f.exit1507 ] + %2320 = phi float [ 0.000000e+00, %53 ], [ %2208, %__nv_exp2f.exit1507 ] + %2321 = phi float [ 0.000000e+00, %53 ], [ %2209, %__nv_exp2f.exit1507 ] + %2322 = phi float [ 0.000000e+00, %53 ], [ %2210, %__nv_exp2f.exit1507 ] + %2323 = phi float [ 0.000000e+00, %53 ], [ %2211, %__nv_exp2f.exit1507 ] + %2324 = phi float [ 0.000000e+00, %53 ], [ %2212, %__nv_exp2f.exit1507 ] + %2325 = phi float [ 0.000000e+00, %53 ], [ %2213, %__nv_exp2f.exit1507 ] + %2326 = phi float [ 0.000000e+00, %53 ], [ %2214, %__nv_exp2f.exit1507 ] + %2327 = phi float [ 0.000000e+00, %53 ], [ %2215, %__nv_exp2f.exit1507 ] + %2328 = phi float [ 0.000000e+00, %53 ], [ %2216, %__nv_exp2f.exit1507 ] + %2329 = phi float [ 0.000000e+00, %53 ], [ %2217, %__nv_exp2f.exit1507 ] + %2330 = phi float [ 0.000000e+00, %53 ], [ %2218, %__nv_exp2f.exit1507 ] + %2331 = phi float [ 0.000000e+00, %53 ], [ %2219, %__nv_exp2f.exit1507 ] + %2332 = phi float [ 0.000000e+00, %53 ], [ %2220, %__nv_exp2f.exit1507 ] + %2333 = phi float [ 0.000000e+00, %53 ], [ %2221, %__nv_exp2f.exit1507 ] + %2334 = phi float [ 0.000000e+00, %53 ], [ %2222, %__nv_exp2f.exit1507 ] + %2335 = phi float [ 0.000000e+00, %53 ], [ %2223, %__nv_exp2f.exit1507 ] + %2336 = phi float [ 0.000000e+00, %53 ], [ %2224, %__nv_exp2f.exit1507 ] + %2337 = phi float [ 0.000000e+00, %53 ], [ %2225, %__nv_exp2f.exit1507 ] + %2338 = phi float [ 0.000000e+00, %53 ], [ %2226, %__nv_exp2f.exit1507 ] + %2339 = phi float [ 0.000000e+00, %53 ], [ %2227, %__nv_exp2f.exit1507 ] + %2340 = phi float [ 0.000000e+00, %53 ], [ %2228, %__nv_exp2f.exit1507 ] + %2341 = phi float [ 0.000000e+00, %53 ], [ %2229, %__nv_exp2f.exit1507 ] + %2342 = phi float [ 0.000000e+00, %53 ], [ %2230, %__nv_exp2f.exit1507 ] + %2343 = phi float [ 0.000000e+00, %53 ], [ %2231, %__nv_exp2f.exit1507 ] + %2344 = phi float [ 0.000000e+00, %53 ], [ %2232, %__nv_exp2f.exit1507 ] + %2345 = phi float [ 0.000000e+00, %53 ], [ %2233, %__nv_exp2f.exit1507 ] + %2346 = phi float 
[ 0.000000e+00, %53 ], [ %2234, %__nv_exp2f.exit1507 ] + %2347 = phi float [ 0.000000e+00, %53 ], [ %2235, %__nv_exp2f.exit1507 ] + %2348 = phi float [ 0.000000e+00, %53 ], [ %2236, %__nv_exp2f.exit1507 ] + %2349 = phi float [ 0.000000e+00, %53 ], [ %2237, %__nv_exp2f.exit1507 ] + %2350 = phi float [ 0.000000e+00, %53 ], [ %2238, %__nv_exp2f.exit1507 ] + %2351 = phi float [ 0.000000e+00, %53 ], [ %2239, %__nv_exp2f.exit1507 ] + %2352 = phi float [ 0.000000e+00, %53 ], [ %2240, %__nv_exp2f.exit1507 ] + %2353 = phi float [ 0.000000e+00, %53 ], [ %2241, %__nv_exp2f.exit1507 ] + %2354 = phi float [ 0.000000e+00, %53 ], [ %2242, %__nv_exp2f.exit1507 ] + %2355 = phi float [ 0.000000e+00, %53 ], [ %2243, %__nv_exp2f.exit1507 ] + %2356 = phi float [ 0.000000e+00, %53 ], [ %2244, %__nv_exp2f.exit1507 ] + %2357 = phi float [ 0.000000e+00, %53 ], [ %2245, %__nv_exp2f.exit1507 ] + %2358 = phi float [ 0.000000e+00, %53 ], [ %2246, %__nv_exp2f.exit1507 ] + %2359 = phi float [ 0.000000e+00, %53 ], [ %2247, %__nv_exp2f.exit1507 ] + %2360 = phi float [ 0.000000e+00, %53 ], [ %2248, %__nv_exp2f.exit1507 ] + %2361 = phi float [ 0.000000e+00, %53 ], [ %2249, %__nv_exp2f.exit1507 ] + %2362 = phi float [ 0.000000e+00, %53 ], [ %2250, %__nv_exp2f.exit1507 ] + %2363 = phi float [ 0.000000e+00, %53 ], [ %2251, %__nv_exp2f.exit1507 ] + %2364 = phi float [ 0.000000e+00, %53 ], [ %2252, %__nv_exp2f.exit1507 ] + %2365 = phi float [ 0.000000e+00, %53 ], [ %2253, %__nv_exp2f.exit1507 ] + %2366 = phi float [ 0.000000e+00, %53 ], [ %2254, %__nv_exp2f.exit1507 ] + %2367 = phi float [ 0.000000e+00, %53 ], [ %2255, %__nv_exp2f.exit1507 ] + %2368 = phi float [ 0.000000e+00, %53 ], [ %2256, %__nv_exp2f.exit1507 ] + %2369 = phi float [ 0.000000e+00, %53 ], [ %2257, %__nv_exp2f.exit1507 ] + %2370 = phi float [ 0.000000e+00, %53 ], [ %2258, %__nv_exp2f.exit1507 ] + %2371 = phi float [ 0.000000e+00, %53 ], [ %2259, %__nv_exp2f.exit1507 ] + %2372 = phi float [ 0.000000e+00, %53 ], [ %2260, 
%__nv_exp2f.exit1507 ] + %2373 = phi float [ 0.000000e+00, %53 ], [ %2261, %__nv_exp2f.exit1507 ] + %2374 = phi float [ 0.000000e+00, %53 ], [ %2262, %__nv_exp2f.exit1507 ] + %2375 = phi float [ 0.000000e+00, %53 ], [ %2263, %__nv_exp2f.exit1507 ] + %2376 = phi float [ 0.000000e+00, %53 ], [ %2264, %__nv_exp2f.exit1507 ] + %2377 = phi float [ 0.000000e+00, %53 ], [ %2265, %__nv_exp2f.exit1507 ] + %2378 = phi float [ 0.000000e+00, %53 ], [ %2266, %__nv_exp2f.exit1507 ] + %2379 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %2315, float %2316, float %2317, float %2318, float %2319, float %2320, float %2321, float %2322, float %2323, float %2324, float %2325, float %2326, float %2327, float %2328, float %2329, float %2330, float %2331, float %2332, float %2333, float %2334, float %2335, float %2336, float %2337, float %2338, float %2339, float %2340, float %2341, float 
%2342, float %2343, float %2344, float %2345, float %2346, float %2347, float %2348, float %2349, float %2350, float %2351, float %2352, float %2353, float %2354, float %2355, float %2356, float %2357, float %2358, float %2359, float %2360, float %2361, float %2362, float %2363, float %2364, float %2365, float %2366, float %2367, float %2368, float %2369, float %2370, float %2371, float %2372, float %2373, float %2374, float %2375, float %2376, float %2377, float %2378) #3, !dbg !84 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !84 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !84 + %2380 = getelementptr i32, ptr addrspace(1) %13, i64 %358, !dbg !136 + %2381 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2380) #3, !dbg !137 + %2382 = shl i32 %2381, 7, !dbg !138 + %2383 = getelementptr i32, ptr addrspace(1) %12, i64 %362, !dbg !139 + %2384 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2383) #3, !dbg !140 + %2385 = or disjoint i32 %2382, %38, !dbg !141 + %2386 = or disjoint i32 %2382, %39, !dbg !141 + %2387 = or disjoint i32 %2382, %40, !dbg !141 + %2388 = or disjoint i32 %2382, %41, !dbg !141 + %2389 = shl i32 %2385, 7, !dbg !142 + %2390 = shl i32 %2386, 7, !dbg !142 + %2391 = shl i32 %2387, 7, !dbg !142 + %2392 = shl i32 %2388, 7, !dbg !142 + %2393 = sext i32 %2389 to i64, !dbg !144 + %2394 = getelementptr bfloat, ptr addrspace(1) %32, i64 %2393, !dbg !144 + %2395 = sext i32 %2390 to i64, !dbg !144 + %2396 = getelementptr bfloat, ptr addrspace(1) %32, i64 %2395, !dbg !144 + %2397 = sext i32 %2391 to i64, !dbg !144 + %2398 = getelementptr bfloat, ptr addrspace(1) %32, i64 %2397, !dbg !144 + %2399 = sext i32 %2392 to i64, !dbg !144 + %2400 = getelementptr bfloat, ptr addrspace(1) %32, i64 %2399, !dbg !144 + %2401 = getelementptr bfloat, ptr addrspace(1) %2394, i64 %121, !dbg !145 + %2402 = 
getelementptr bfloat, ptr addrspace(1) %2396, i64 %121, !dbg !145 + %2403 = getelementptr bfloat, ptr addrspace(1) %2398, i64 %121, !dbg !145 + %2404 = getelementptr bfloat, ptr addrspace(1) %2400, i64 %121, !dbg !145 + %2405 = getelementptr bfloat, ptr addrspace(1) %33, i64 %2393, !dbg !146 + %2406 = getelementptr bfloat, ptr addrspace(1) %33, i64 %2395, !dbg !146 + %2407 = getelementptr bfloat, ptr addrspace(1) %33, i64 %2397, !dbg !146 + %2408 = getelementptr bfloat, ptr addrspace(1) %33, i64 %2399, !dbg !146 + %2409 = getelementptr bfloat, ptr addrspace(1) %2405, i64 %121, !dbg !147 + %2410 = getelementptr bfloat, ptr addrspace(1) %2406, i64 %121, !dbg !147 + %2411 = getelementptr bfloat, ptr addrspace(1) %2407, i64 %121, !dbg !147 + %2412 = getelementptr bfloat, ptr addrspace(1) %2408, i64 %121, !dbg !147 + %2413 = shl i32 %2384, 1, !dbg !148 + %2414 = icmp sgt i32 %2413, 0, !dbg !149 + %2415 = select i1 %2414, i32 16, i32 0, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %403, ptr addrspace(1) %2401, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %406, ptr addrspace(1) %2402, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %408, ptr addrspace(1) %2403, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %410, ptr addrspace(1) %2404, i32 %2415) #3, !dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %411, ptr addrspace(1) %2409, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 
0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %412, ptr addrspace(1) %2410, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %413, ptr addrspace(1) %2411, i32 %2415) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %414, ptr addrspace(1) %2412, i32 %2415) #3, !dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + %2416 = icmp sgt i32 %2413, 1, !dbg !149 + %2417 = getelementptr i8, ptr addrspace(1) %2401, i64 16384, !dbg !151 + %2418 = getelementptr i8, ptr addrspace(1) %2402, i64 16384, !dbg !151 + %2419 = getelementptr i8, ptr addrspace(1) %2403, i64 16384, !dbg !151 + %2420 = getelementptr i8, ptr addrspace(1) %2404, i64 16384, !dbg !151 + %2421 = getelementptr i8, ptr addrspace(1) %2409, i64 16384, !dbg !152 + %2422 = getelementptr i8, ptr addrspace(1) %2410, i64 16384, !dbg !152 + %2423 = getelementptr i8, ptr addrspace(1) %2411, i64 16384, !dbg !152 + %2424 = getelementptr i8, ptr addrspace(1) %2412, i64 16384, !dbg !152 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !150 + %2425 = select i1 %2416, i32 16, i32 0, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %424, ptr addrspace(1) %2417, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %426, ptr addrspace(1) %2418, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %427, ptr addrspace(1) %2419, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %428, ptr addrspace(1) %2420, i32 %2425) #3, 
!dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %429, ptr addrspace(1) %2421, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %430, ptr addrspace(1) %2422, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %431, ptr addrspace(1) %2423, i32 %2425) #3, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %432, ptr addrspace(1) %2424, i32 %2425) #3, !dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !153 + %2426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 0, !dbg !149 + %2427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 1, !dbg !149 + %2428 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 2, !dbg !149 + %2429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 3, !dbg !149 + %2430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 4, !dbg !149 + %2431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 5, !dbg !149 + %2432 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 6, !dbg !149 + %2433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 7, !dbg !149 + %2434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 8, !dbg !149 + %2435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 9, !dbg !149 + %2436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 10, !dbg !149 + %2437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 11, !dbg !149 + %2438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 12, !dbg !149 + %2439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 13, !dbg !149 + %2440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 14, !dbg !149 + %2441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 15, !dbg !149 + %2442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 16, !dbg !149 + %2443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 17, !dbg !149 + %2444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 18, !dbg !149 + %2445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 19, !dbg !149 + %2446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 20, !dbg !149 + %2447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 21, !dbg !149 + %2448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 22, !dbg !149 + %2449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 23, !dbg !149 + %2450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 24, !dbg !149 + %2451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 25, !dbg !149 + %2452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 26, !dbg !149 + %2453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 27, !dbg !149 + %2454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 28, !dbg !149 + %2455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 29, !dbg !149 + %2456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 30, !dbg !149 + %2457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 31, !dbg !149 + %2458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 32, !dbg !149 + %2459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 33, !dbg !149 + %2460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 34, !dbg !149 + %2461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 35, !dbg !149 + %2462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 36, !dbg !149 + %2463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 37, !dbg !149 + %2464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 38, !dbg !149 + %2465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 39, !dbg !149 + %2466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 40, !dbg !149 + %2467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 41, !dbg !149 + %2468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 42, !dbg !149 + %2469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 43, !dbg !149 + %2470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 44, !dbg !149 + %2471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 45, !dbg !149 + %2472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 46, !dbg !149 + %2473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 47, !dbg !149 + %2474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 48, !dbg !149 + %2475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 49, !dbg !149 + %2476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 50, !dbg !149 + %2477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 51, !dbg !149 + %2478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 52, !dbg !149 + %2479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 53, !dbg !149 + %2480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 54, !dbg !149 + %2481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 55, !dbg !149 + %2482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 56, !dbg !149 + %2483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 57, !dbg !149 + %2484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 58, !dbg !149 + %2485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 59, !dbg !149 + %2486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 60, !dbg !149 + %2487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 61, !dbg !149 + %2488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 62, !dbg !149 + %2489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 63, !dbg !149 + br i1 %2414, label %.lr.ph1858, label %._crit_edge1859, !dbg !149 + +.lr.ph1858: ; preds = %._crit_edge1847 + %2490 = tail call i32 @llvm.umin.i32(i32 %2413, i32 32), !dbg !154 + %2491 = add nsw i32 %2490, -2 + %2492 = add nsw i32 %2490, -1 + %2493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 0, !dbg !149 + %2494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 1, !dbg !149 + %2495 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 2, !dbg !149 + %2496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 3, !dbg !149 + %2497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 4, !dbg !149 + %2498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 5, !dbg !149 + %2499 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 6, !dbg !149 + %2500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 7, !dbg !149 + %2501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 8, !dbg !149 + %2502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 9, !dbg !149 + %2503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 10, !dbg !149 + %2504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 11, !dbg !149 + %2505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 12, !dbg !149 + %2506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 13, !dbg !149 + %2507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 14, !dbg !149 + %2508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 15, !dbg !149 + %2509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 16, !dbg !149 + %2510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 17, !dbg !149 + %2511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 18, !dbg !149 + %2512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 19, !dbg !149 + %2513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 20, !dbg !149 + %2514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 21, !dbg !149 + %2515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 22, !dbg !149 + %2516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 23, !dbg !149 + %2517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 24, !dbg !149 + %2518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 25, !dbg !149 + %2519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 26, !dbg !149 + %2520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 27, !dbg !149 + %2521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 28, !dbg !149 + %2522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 29, !dbg !149 + %2523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 30, !dbg !149 + %2524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 31, !dbg !149 + %2525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 32, !dbg !149 + %2526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 33, !dbg !149 + %2527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 34, !dbg !149 + %2528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 35, !dbg !149 + %2529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 36, !dbg !149 + %2530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 37, !dbg !149 + %2531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 38, !dbg !149 + %2532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 39, !dbg !149 + %2533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 40, !dbg !149 + %2534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 41, !dbg !149 + %2535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 42, !dbg !149 + %2536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 43, !dbg !149 + %2537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 44, !dbg !149 + %2538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 45, !dbg !149 + %2539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 46, !dbg !149 + %2540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 47, !dbg !149 + %2541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 48, !dbg !149 + %2542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 49, !dbg !149 + %2543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 50, !dbg !149 + %2544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 51, !dbg !149 + %2545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 52, !dbg !149 + %2546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 53, !dbg !149 + %2547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 54, !dbg !149 + %2548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 55, !dbg !149 + %2549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 56, !dbg !149 + %2550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 57, !dbg !149 + %2551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 58, !dbg !149 + %2552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 59, !dbg !149 + %2553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 60, !dbg !149 + %2554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 61, !dbg !149 + %2555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 62, !dbg !149 + %2556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2379, 63, !dbg !149 + %2557 = insertelement <2 x float> poison, float %347, i64 0, !dbg !155 + %2558 = shufflevector <2 x float> %2557, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !155 + %2559 = insertelement <2 x float> poison, float %345, i64 0, !dbg !155 + %2560 = shufflevector <2 x float> %2559, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !155 + br label %2561, !dbg !149 + +2561: ; preds = %.lr.ph1858, %__nv_exp2f.exit1411 + %2562 = phi i32 [ -1, %.lr.ph1858 ], [ %2569, %__nv_exp2f.exit1411 ] + %2563 = phi i32 [ 1, %.lr.ph1858 ], [ %4028, %__nv_exp2f.exit1411 ] + %.pn10781856 = phi ptr addrspace(1) [ %2424, %.lr.ph1858 ], [ %4025, %__nv_exp2f.exit1411 ] + %.pn10941855 = phi ptr addrspace(1) [ %2423, %.lr.ph1858 ], [ %4024, %__nv_exp2f.exit1411 ] + %.pn11101854 = phi ptr addrspace(1) [ %2422, %.lr.ph1858 ], [ %4023, %__nv_exp2f.exit1411 ] + %.pn11261853 = phi ptr addrspace(1) [ %2421, %.lr.ph1858 ], [ %4022, %__nv_exp2f.exit1411 ] + %.pn10141852 = phi ptr addrspace(1) [ %2420, %.lr.ph1858 ], [ %4021, %__nv_exp2f.exit1411 ] + %.pn10301851 = phi ptr addrspace(1) [ %2419, %.lr.ph1858 ], [ %4020, %__nv_exp2f.exit1411 ] + %.pn10461850 = phi ptr addrspace(1) [ %2418, %.lr.ph1858 ], [ %4019, %__nv_exp2f.exit1411 ] + %.pn10621849 = phi ptr addrspace(1) [ %2417, %.lr.ph1858 ], [ %4018, 
%__nv_exp2f.exit1411 ] + %.pn = phi float [ %2493, %.lr.ph1858 ], [ %3932, %__nv_exp2f.exit1411 ] + %.pn2389 = phi float [ %2494, %.lr.ph1858 ], [ %3933, %__nv_exp2f.exit1411 ] + %.pn2390 = phi float [ %2495, %.lr.ph1858 ], [ %3934, %__nv_exp2f.exit1411 ] + %.pn2391 = phi float [ %2496, %.lr.ph1858 ], [ %3935, %__nv_exp2f.exit1411 ] + %.pn2392 = phi float [ %2497, %.lr.ph1858 ], [ %3936, %__nv_exp2f.exit1411 ] + %.pn2393 = phi float [ %2498, %.lr.ph1858 ], [ %3937, %__nv_exp2f.exit1411 ] + %.pn2394 = phi float [ %2499, %.lr.ph1858 ], [ %3938, %__nv_exp2f.exit1411 ] + %.pn2395 = phi float [ %2500, %.lr.ph1858 ], [ %3939, %__nv_exp2f.exit1411 ] + %.pn2396 = phi float [ %2501, %.lr.ph1858 ], [ %3940, %__nv_exp2f.exit1411 ] + %.pn2397 = phi float [ %2502, %.lr.ph1858 ], [ %3941, %__nv_exp2f.exit1411 ] + %.pn2398 = phi float [ %2503, %.lr.ph1858 ], [ %3942, %__nv_exp2f.exit1411 ] + %.pn2399 = phi float [ %2504, %.lr.ph1858 ], [ %3943, %__nv_exp2f.exit1411 ] + %.pn2400 = phi float [ %2505, %.lr.ph1858 ], [ %3944, %__nv_exp2f.exit1411 ] + %.pn2401 = phi float [ %2506, %.lr.ph1858 ], [ %3945, %__nv_exp2f.exit1411 ] + %.pn2402 = phi float [ %2507, %.lr.ph1858 ], [ %3946, %__nv_exp2f.exit1411 ] + %.pn2403 = phi float [ %2508, %.lr.ph1858 ], [ %3947, %__nv_exp2f.exit1411 ] + %.pn2404 = phi float [ %2509, %.lr.ph1858 ], [ %3948, %__nv_exp2f.exit1411 ] + %.pn2405 = phi float [ %2510, %.lr.ph1858 ], [ %3949, %__nv_exp2f.exit1411 ] + %.pn2406 = phi float [ %2511, %.lr.ph1858 ], [ %3950, %__nv_exp2f.exit1411 ] + %.pn2407 = phi float [ %2512, %.lr.ph1858 ], [ %3951, %__nv_exp2f.exit1411 ] + %.pn2408 = phi float [ %2513, %.lr.ph1858 ], [ %3952, %__nv_exp2f.exit1411 ] + %.pn2409 = phi float [ %2514, %.lr.ph1858 ], [ %3953, %__nv_exp2f.exit1411 ] + %.pn2410 = phi float [ %2515, %.lr.ph1858 ], [ %3954, %__nv_exp2f.exit1411 ] + %.pn2411 = phi float [ %2516, %.lr.ph1858 ], [ %3955, %__nv_exp2f.exit1411 ] + %.pn2412 = phi float [ %2517, %.lr.ph1858 ], [ %3956, %__nv_exp2f.exit1411 ] + 
%.pn2413 = phi float [ %2518, %.lr.ph1858 ], [ %3957, %__nv_exp2f.exit1411 ] + %.pn2414 = phi float [ %2519, %.lr.ph1858 ], [ %3958, %__nv_exp2f.exit1411 ] + %.pn2415 = phi float [ %2520, %.lr.ph1858 ], [ %3959, %__nv_exp2f.exit1411 ] + %.pn2416 = phi float [ %2521, %.lr.ph1858 ], [ %3960, %__nv_exp2f.exit1411 ] + %.pn2417 = phi float [ %2522, %.lr.ph1858 ], [ %3961, %__nv_exp2f.exit1411 ] + %.pn2418 = phi float [ %2523, %.lr.ph1858 ], [ %3962, %__nv_exp2f.exit1411 ] + %.pn2419 = phi float [ %2524, %.lr.ph1858 ], [ %3963, %__nv_exp2f.exit1411 ] + %.pn2420 = phi float [ %2525, %.lr.ph1858 ], [ %3964, %__nv_exp2f.exit1411 ] + %.pn2421 = phi float [ %2526, %.lr.ph1858 ], [ %3965, %__nv_exp2f.exit1411 ] + %.pn2422 = phi float [ %2527, %.lr.ph1858 ], [ %3966, %__nv_exp2f.exit1411 ] + %.pn2423 = phi float [ %2528, %.lr.ph1858 ], [ %3967, %__nv_exp2f.exit1411 ] + %.pn2424 = phi float [ %2529, %.lr.ph1858 ], [ %3968, %__nv_exp2f.exit1411 ] + %.pn2425 = phi float [ %2530, %.lr.ph1858 ], [ %3969, %__nv_exp2f.exit1411 ] + %.pn2426 = phi float [ %2531, %.lr.ph1858 ], [ %3970, %__nv_exp2f.exit1411 ] + %.pn2427 = phi float [ %2532, %.lr.ph1858 ], [ %3971, %__nv_exp2f.exit1411 ] + %.pn2428 = phi float [ %2533, %.lr.ph1858 ], [ %3972, %__nv_exp2f.exit1411 ] + %.pn2429 = phi float [ %2534, %.lr.ph1858 ], [ %3973, %__nv_exp2f.exit1411 ] + %.pn2430 = phi float [ %2535, %.lr.ph1858 ], [ %3974, %__nv_exp2f.exit1411 ] + %.pn2431 = phi float [ %2536, %.lr.ph1858 ], [ %3975, %__nv_exp2f.exit1411 ] + %.pn2432 = phi float [ %2537, %.lr.ph1858 ], [ %3976, %__nv_exp2f.exit1411 ] + %.pn2433 = phi float [ %2538, %.lr.ph1858 ], [ %3977, %__nv_exp2f.exit1411 ] + %.pn2434 = phi float [ %2539, %.lr.ph1858 ], [ %3978, %__nv_exp2f.exit1411 ] + %.pn2435 = phi float [ %2540, %.lr.ph1858 ], [ %3979, %__nv_exp2f.exit1411 ] + %.pn2436 = phi float [ %2541, %.lr.ph1858 ], [ %3980, %__nv_exp2f.exit1411 ] + %.pn2437 = phi float [ %2542, %.lr.ph1858 ], [ %3981, %__nv_exp2f.exit1411 ] + %.pn2438 = phi float [ 
%2543, %.lr.ph1858 ], [ %3982, %__nv_exp2f.exit1411 ] + %.pn2439 = phi float [ %2544, %.lr.ph1858 ], [ %3983, %__nv_exp2f.exit1411 ] + %.pn2440 = phi float [ %2545, %.lr.ph1858 ], [ %3984, %__nv_exp2f.exit1411 ] + %.pn2441 = phi float [ %2546, %.lr.ph1858 ], [ %3985, %__nv_exp2f.exit1411 ] + %.pn2442 = phi float [ %2547, %.lr.ph1858 ], [ %3986, %__nv_exp2f.exit1411 ] + %.pn2443 = phi float [ %2548, %.lr.ph1858 ], [ %3987, %__nv_exp2f.exit1411 ] + %.pn2444 = phi float [ %2549, %.lr.ph1858 ], [ %3988, %__nv_exp2f.exit1411 ] + %.pn2445 = phi float [ %2550, %.lr.ph1858 ], [ %3989, %__nv_exp2f.exit1411 ] + %.pn2446 = phi float [ %2551, %.lr.ph1858 ], [ %3990, %__nv_exp2f.exit1411 ] + %.pn2447 = phi float [ %2552, %.lr.ph1858 ], [ %3991, %__nv_exp2f.exit1411 ] + %.pn2448 = phi float [ %2553, %.lr.ph1858 ], [ %3992, %__nv_exp2f.exit1411 ] + %.pn2449 = phi float [ %2554, %.lr.ph1858 ], [ %3993, %__nv_exp2f.exit1411 ] + %.pn2450 = phi float [ %2555, %.lr.ph1858 ], [ %3994, %__nv_exp2f.exit1411 ] + %.pn2451 = phi float [ %2556, %.lr.ph1858 ], [ %3995, %__nv_exp2f.exit1411 ] + %2564 = phi i32 [ 0, %.lr.ph1858 ], [ %3996, %__nv_exp2f.exit1411 ] + %2565 = icmp slt i32 %2564, %2491, !dbg !149 + %2566 = icmp slt i32 %2564, %2492, !dbg !149 + %2567 = add i32 %2562, 1, !dbg !149 + %2568 = icmp sgt i32 %2567, 2, !dbg !149 + %2569 = select i1 %2568, i32 0, i32 %2567, !dbg !149 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !150 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !150 + %2570 = shl i32 %2569, 13, !dbg !150 + %2571 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2570, !dbg !150 + %2572 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %36, i32 0, i32 31), !dbg !153 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !153 + %2573 = shl i32 %2572, 11, !dbg !153 + %2574 = and i32 %2573, 8192, !dbg !153 + %2575 = add i32 %2574, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) 
to i32), !dbg !153 + %2576 = lshr exact i32 %2575, 4, !dbg !153 + %2577 = and i32 %2576, 16383, !dbg !153 + %2578 = zext nneg i32 %2577 to i64, !dbg !153 + %2579 = or disjoint i64 %2578, 4611686293372403712, !dbg !153 + %2580 = ptrtoint ptr addrspace(3) %2571 to i32, !dbg !153 + %2581 = lshr exact i32 %2580, 4, !dbg !153 + %2582 = and i32 %2581, 16383, !dbg !153 + %2583 = zext nneg i32 %2582 to i64, !dbg !153 + %2584 = or disjoint i64 %2583, 4611686293338849280, !dbg !153 + %2585 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %2579, i64 %2584) #3, !dbg !153 + %2586 = or disjoint i32 %2574, 32, !dbg !153 + %2587 = add i32 %2586, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2588 = lshr exact i32 %2587, 4, !dbg !153 + %2589 = and i32 %2588, 16383, !dbg !153 + %2590 = zext nneg i32 %2589 to i64, !dbg !153 + %2591 = or disjoint i64 %2590, 4611686293372403712, !dbg !153 + %2592 = add i32 %2580, 32, !dbg !153 + %2593 = lshr exact i32 %2592, 4, !dbg !153 + %2594 = and i32 %2593, 16383, !dbg !153 + %2595 = zext nneg i32 %2594 to i64, !dbg !153 + %2596 = or disjoint i64 %2595, 4611686293338849280, !dbg !153 + %2597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 0, !dbg !153 + %2598 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 1, !dbg !153 + %2599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 2, !dbg !153 + %2600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 3, !dbg !153 + %2601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 4, !dbg !153 + %2602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 5, !dbg !153 + %2603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 6, !dbg !153 + %2604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 7, !dbg !153 + %2605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 8, !dbg !153 + %2606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 9, !dbg !153 + %2607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 10, !dbg !153 + %2608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 11, !dbg !153 + %2609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 12, !dbg !153 + %2610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 13, !dbg !153 + %2611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 14, !dbg !153 + %2612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %2585, 15, !dbg !153 + %2613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 16, !dbg !153 + %2614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 17, !dbg !153 + %2615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 18, !dbg !153 + %2616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 19, !dbg !153 + %2617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 20, !dbg !153 + %2618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 21, !dbg !153 + %2619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 22, !dbg !153 + %2620 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 23, !dbg !153 + %2621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 24, !dbg !153 + %2622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 25, !dbg !153 + %2623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 26, !dbg !153 + %2624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 27, !dbg !153 + %2625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 28, !dbg !153 + %2626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 29, !dbg !153 + %2627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %2585, 30, !dbg !153 + %2628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2585, 31, !dbg !153 + %2629 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2597, float %2598, float %2599, float %2600, float %2601, float %2602, float %2603, float %2604, float %2605, float %2606, float %2607, float %2608, float %2609, float %2610, float %2611, float %2612, float %2613, float %2614, float %2615, float %2616, float %2617, float %2618, float %2619, float %2620, float %2621, float %2622, float %2623, float %2624, float %2625, float %2626, float %2627, float %2628, i64 %2591, i64 %2596, i1 true) #3, !dbg !153 + %2630 = or disjoint i32 %2574, 64, !dbg !153 + %2631 = add i32 %2630, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2632 = lshr exact i32 %2631, 4, !dbg !153 + %2633 = and i32 %2632, 16383, !dbg !153 + %2634 = zext nneg i32 %2633 to i64, !dbg !153 + %2635 = or disjoint i64 %2634, 4611686293372403712, !dbg !153 + %2636 = add i32 %2580, 64, !dbg !153 + %2637 = lshr exact i32 %2636, 4, !dbg !153 + %2638 = and i32 %2637, 16383, !dbg !153 + %2639 
= zext nneg i32 %2638 to i64, !dbg !153 + %2640 = or disjoint i64 %2639, 4611686293338849280, !dbg !153 + %2641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 0, !dbg !153 + %2642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 1, !dbg !153 + %2643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 2, !dbg !153 + %2644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 3, !dbg !153 + %2645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 4, !dbg !153 + %2646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 5, !dbg !153 + %2647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 6, !dbg !153 + 
%2648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 7, !dbg !153 + %2649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 8, !dbg !153 + %2650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 9, !dbg !153 + %2651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 10, !dbg !153 + %2652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 11, !dbg !153 + %2653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 12, !dbg !153 + %2654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 13, !dbg !153 + %2655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 14, !dbg !153 + %2656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 15, !dbg !153 + %2657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 16, !dbg !153 + %2658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 17, !dbg !153 + %2659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 18, !dbg !153 + %2660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 19, !dbg !153 + %2661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 20, !dbg !153 + %2662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %2629, 21, !dbg !153 + %2663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 22, !dbg !153 + %2664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 23, !dbg !153 + %2665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 24, !dbg !153 + %2666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 25, !dbg !153 + %2667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 26, !dbg !153 + %2668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 27, !dbg !153 + %2669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 28, !dbg !153 + %2670 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 29, !dbg !153 + %2671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 30, !dbg !153 + %2672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2629, 31, !dbg !153 + %2673 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2641, float %2642, float %2643, float %2644, float %2645, float %2646, float %2647, float %2648, float %2649, float %2650, float %2651, float %2652, float %2653, float %2654, float %2655, float %2656, float %2657, float %2658, float %2659, float %2660, float %2661, float %2662, float %2663, float %2664, float %2665, float %2666, float %2667, float %2668, float %2669, float %2670, float %2671, float %2672, i64 %2635, i64 %2640, i1 true) #3, !dbg !153 + %2674 = or disjoint i32 %2574, 96, !dbg !153 + %2675 = add i32 %2674, ptrtoint (ptr addrspace(3) 
getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2676 = lshr exact i32 %2675, 4, !dbg !153 + %2677 = and i32 %2676, 16383, !dbg !153 + %2678 = zext nneg i32 %2677 to i64, !dbg !153 + %2679 = or disjoint i64 %2678, 4611686293372403712, !dbg !153 + %2680 = add i32 %2580, 96, !dbg !153 + %2681 = lshr exact i32 %2680, 4, !dbg !153 + %2682 = and i32 %2681, 16383, !dbg !153 + %2683 = zext nneg i32 %2682 to i64, !dbg !153 + %2684 = or disjoint i64 %2683, 4611686293338849280, !dbg !153 + %2685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 0, !dbg !153 + %2686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 1, !dbg !153 + %2687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 2, !dbg !153 + %2688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 3, !dbg !153 + %2689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 4, !dbg !153 + %2690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 5, !dbg !153 + %2691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 6, !dbg !153 + %2692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 7, !dbg !153 + %2693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 8, !dbg !153 + %2694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 9, !dbg !153 + %2695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 10, !dbg !153 + %2696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 11, !dbg !153 + %2697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %2673, 12, !dbg !153 + %2698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 13, !dbg !153 + %2699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 14, !dbg !153 + %2700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 15, !dbg !153 + %2701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 16, !dbg !153 + %2702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 17, !dbg !153 + %2703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 18, !dbg !153 + %2704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 19, !dbg !153 + %2705 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 20, !dbg !153 + %2706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 21, !dbg !153 + %2707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 22, !dbg !153 + %2708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 23, !dbg !153 + %2709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 24, !dbg !153 + %2710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 25, !dbg !153 + %2711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 26, !dbg !153 + %2712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %2673, 27, !dbg !153 + %2713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 28, !dbg !153 + %2714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 29, !dbg !153 + %2715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 30, !dbg !153 + %2716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2673, 31, !dbg !153 + %2717 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2685, float %2686, float %2687, float %2688, float %2689, float %2690, float %2691, float %2692, float %2693, float %2694, float %2695, float 
%2696, float %2697, float %2698, float %2699, float %2700, float %2701, float %2702, float %2703, float %2704, float %2705, float %2706, float %2707, float %2708, float %2709, float %2710, float %2711, float %2712, float %2713, float %2714, float %2715, float %2716, i64 %2679, i64 %2684, i1 true) #3, !dbg !153 + %2718 = or disjoint i32 %2574, 16384, !dbg !153 + %2719 = add i32 %2718, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2720 = lshr exact i32 %2719, 4, !dbg !153 + %2721 = and i32 %2720, 16383, !dbg !153 + %2722 = zext nneg i32 %2721 to i64, !dbg !153 + %2723 = or disjoint i64 %2722, 4611686293372403712, !dbg !153 + %2724 = add i32 %2580, 8192, !dbg !153 + %2725 = lshr exact i32 %2724, 4, !dbg !153 + %2726 = and i32 %2725, 16383, !dbg !153 + %2727 = zext nneg i32 %2726 to i64, !dbg !153 + %2728 = or disjoint i64 %2727, 4611686293338849280, !dbg !153 + %2729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 0, !dbg !153 + %2730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 1, !dbg !153 + %2731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 2, !dbg !153 + %2732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 
3, !dbg !153 + %2733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 4, !dbg !153 + %2734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 5, !dbg !153 + %2735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 6, !dbg !153 + %2736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 7, !dbg !153 + %2737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 8, !dbg !153 + %2738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 9, !dbg !153 + %2739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 10, !dbg !153 + %2740 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 11, !dbg !153 + %2741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 12, !dbg !153 + %2742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 13, !dbg !153 + %2743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 14, !dbg !153 + %2744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 15, !dbg !153 + %2745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 16, !dbg !153 + %2746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 17, !dbg !153 + %2747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %2717, 18, !dbg !153 + %2748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 19, !dbg !153 + %2749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 20, !dbg !153 + %2750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 21, !dbg !153 + %2751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 22, !dbg !153 + %2752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 23, !dbg !153 + %2753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 24, !dbg !153 + %2754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 25, !dbg !153 + %2755 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 26, !dbg !153 + %2756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 27, !dbg !153 + %2757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 28, !dbg !153 + %2758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 29, !dbg !153 + %2759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 30, !dbg !153 + %2760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2717, 31, !dbg !153 + %2761 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2729, float %2730, float %2731, float %2732, float %2733, float %2734, float %2735, float %2736, float %2737, float %2738, float %2739, float %2740, float %2741, float %2742, float %2743, float %2744, float %2745, float %2746, float %2747, float %2748, float %2749, float %2750, float %2751, float %2752, float %2753, float %2754, float %2755, float %2756, float %2757, float %2758, float %2759, float %2760, i64 %2723, i64 %2728, i1 true) #3, !dbg !153 + %2762 = or disjoint i32 %2574, 16416, !dbg !153 + %2763 = add i32 %2762, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2764 = lshr exact i32 %2763, 4, !dbg !153 + %2765 = and i32 %2764, 16383, !dbg !153 + %2766 = zext nneg i32 %2765 to i64, !dbg !153 + %2767 = or disjoint i64 %2766, 4611686293372403712, !dbg !153 + %2768 = add i32 %2580, 8224, !dbg !153 + %2769 = lshr exact i32 %2768, 4, !dbg !153 + %2770 = and i32 %2769, 16383, !dbg !153 + %2771 = zext nneg i32 %2770 to i64, !dbg !153 + %2772 = or disjoint i64 %2771, 4611686293338849280, !dbg !153 + %2773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 0, !dbg !153 + %2774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 1, !dbg !153 + %2775 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 2, !dbg !153 + %2776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 3, !dbg !153 + %2777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 4, !dbg !153 + %2778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 5, !dbg !153 + %2779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 6, !dbg !153 + %2780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 7, !dbg !153 + %2781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 8, !dbg !153 + %2782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 9, !dbg !153 + %2783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 10, !dbg !153 + %2784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 11, !dbg !153 + %2785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 12, !dbg !153 + %2786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 13, !dbg !153 + %2787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 14, !dbg !153 + %2788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 15, !dbg !153 + %2789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %2761, 16, !dbg !153 + %2790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 17, !dbg !153 + %2791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 18, !dbg !153 + %2792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 19, !dbg !153 + %2793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 20, !dbg !153 + %2794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 21, !dbg !153 + %2795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 22, !dbg !153 + %2796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 23, !dbg !153 + %2797 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 24, !dbg !153 + %2798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 25, !dbg !153 + %2799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 26, !dbg !153 + %2800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 27, !dbg !153 + %2801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 28, !dbg !153 + %2802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 29, !dbg !153 + %2803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2761, 30, !dbg !153 + %2804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %2761, 31, !dbg !153 + %2805 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2773, float %2774, float %2775, float %2776, float %2777, float %2778, float %2779, float %2780, float %2781, float %2782, float %2783, float %2784, float %2785, float %2786, float %2787, float %2788, float %2789, float %2790, float %2791, float %2792, float %2793, float %2794, float %2795, float %2796, float %2797, float %2798, float %2799, float %2800, float %2801, float %2802, float %2803, float %2804, i64 %2767, i64 %2772, i1 true) #3, !dbg !153 + %2806 = or disjoint i32 %2574, 16448, !dbg !153 + %2807 = add i32 %2806, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2808 = lshr exact i32 %2807, 4, !dbg !153 + %2809 = and i32 %2808, 16383, !dbg !153 + %2810 = zext nneg i32 %2809 to i64, !dbg !153 + %2811 = or disjoint i64 %2810, 4611686293372403712, !dbg !153 + %2812 = add i32 %2580, 8256, !dbg !153 + %2813 = lshr exact i32 %2812, 4, !dbg !153 + %2814 = and i32 %2813, 16383, !dbg !153 + %2815 = zext nneg i32 %2814 to i64, !dbg !153 + %2816 = or disjoint i64 %2815, 4611686293338849280, !dbg !153 + %2817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 0, !dbg !153 + %2818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 1, !dbg !153 + %2819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 2, !dbg !153 + %2820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 3, !dbg !153 + %2821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 4, !dbg !153 + %2822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 5, !dbg !153 + %2823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 6, !dbg !153 + %2824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 7, 
!dbg !153 + %2825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 8, !dbg !153 + %2826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 9, !dbg !153 + %2827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 10, !dbg !153 + %2828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 11, !dbg !153 + %2829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 12, !dbg !153 + %2830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 13, !dbg !153 + %2831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 14, !dbg !153 + %2832 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 15, !dbg !153 + %2833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 16, !dbg !153 + %2834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 17, !dbg !153 + %2835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 18, !dbg !153 + %2836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 19, !dbg !153 + %2837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 20, !dbg !153 + %2838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 21, !dbg !153 + %2839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %2805, 22, !dbg !153 + %2840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 23, !dbg !153 + %2841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 24, !dbg !153 + %2842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 25, !dbg !153 + %2843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 26, !dbg !153 + %2844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 27, !dbg !153 + %2845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 28, !dbg !153 + %2846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 29, !dbg !153 + %2847 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 30, !dbg !153 + %2848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2805, 31, !dbg !153 + %2849 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2817, float %2818, float %2819, float %2820, float %2821, float %2822, float %2823, float %2824, float %2825, float %2826, float %2827, float %2828, float %2829, float %2830, float %2831, float %2832, float %2833, float %2834, float %2835, float %2836, float %2837, float %2838, float %2839, float %2840, float %2841, float %2842, float %2843, float %2844, float %2845, float %2846, float %2847, float %2848, i64 %2811, i64 %2816, i1 true) #3, !dbg !153 + %2850 = or disjoint i32 %2574, 16480, !dbg !153 + %2851 = add i32 %2850, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !153 + %2852 = lshr exact i32 %2851, 4, !dbg !153 + %2853 = and i32 %2852, 16383, !dbg !153 + %2854 = zext nneg i32 %2853 to i64, !dbg !153 + %2855 = or disjoint i64 %2854, 
4611686293372403712, !dbg !153 + %2856 = add i32 %2580, 8288, !dbg !153 + %2857 = lshr exact i32 %2856, 4, !dbg !153 + %2858 = and i32 %2857, 16383, !dbg !153 + %2859 = zext nneg i32 %2858 to i64, !dbg !153 + %2860 = or disjoint i64 %2859, 4611686293338849280, !dbg !153 + %2861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 0, !dbg !153 + %2862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 1, !dbg !153 + %2863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 2, !dbg !153 + %2864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 3, !dbg !153 + %2865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 4, !dbg !153 + %2866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 5, !dbg !153 + %2867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 6, !dbg !153 + %2868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 7, !dbg !153 + %2869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 8, !dbg !153 + %2870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 9, !dbg !153 + %2871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 10, !dbg !153 + %2872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 11, !dbg !153 + %2873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 12, !dbg !153 + %2874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %2849, 13, !dbg !153 + %2875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 14, !dbg !153 + %2876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 15, !dbg !153 + %2877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 16, !dbg !153 + %2878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 17, !dbg !153 + %2879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 18, !dbg !153 + %2880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 19, !dbg !153 + %2881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 20, !dbg !153 + %2882 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 21, !dbg !153 + %2883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 22, !dbg !153 + %2884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 23, !dbg !153 + %2885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 24, !dbg !153 + %2886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 25, !dbg !153 + %2887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 26, !dbg !153 + %2888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 27, !dbg !153 + %2889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 28, !dbg !153 + %2890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 29, !dbg !153 + %2891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 30, !dbg !153 + %2892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2849, 31, !dbg !153 + %2893 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2861, float %2862, float %2863, float %2864, float %2865, float %2866, float %2867, float %2868, float %2869, float %2870, float %2871, float %2872, float %2873, float %2874, float %2875, float %2876, float %2877, float %2878, float %2879, float %2880, float %2881, float %2882, float %2883, float %2884, float %2885, float %2886, float %2887, float %2888, float %2889, float %2890, float %2891, 
float %2892, i64 %2855, i64 %2860, i1 true) #3, !dbg !153 + %2894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 0, !dbg !153 + %2895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 1, !dbg !153 + %2896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 2, !dbg !153 + %2897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 3, !dbg !153 + %2898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 4, !dbg !153 + %2899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 5, !dbg !153 + %2900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 6, !dbg !153 + %2901 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 7, !dbg !153 + %2902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 8, !dbg !153 + %2903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 9, !dbg !153 + %2904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 10, !dbg !153 + %2905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 11, !dbg !153 + %2906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 12, !dbg !153 + %2907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 13, !dbg !153 + %2908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 14, !dbg !153 + %2909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 15, !dbg !153 + %2910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 16, !dbg !153 + %2911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 17, !dbg !153 + %2912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 18, !dbg !153 + %2913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 19, !dbg !153 + %2914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 20, !dbg !153 + %2915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%2893, 21, !dbg !153 + %2916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 22, !dbg !153 + %2917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 23, !dbg !153 + %2918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 24, !dbg !153 + %2919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 25, !dbg !153 + %2920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 26, !dbg !153 + %2921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 27, !dbg !153 + %2922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 28, !dbg !153 + %2923 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 29, !dbg !153 + %2924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 30, !dbg !153 + %2925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2893, 31, !dbg !153 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !153 + %2926 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %2894, float %2895, float %2896, float %2897, float %2898, float %2899, float %2900, float %2901, float %2902, float %2903, float %2904, float %2905, float %2906, float %2907, float %2908, float %2909, float %2910, float %2911, float %2912, float %2913, float %2914, float %2915, float %2916, float %2917, float %2918, float %2919, float %2920, float %2921, float %2922, float %2923, float %2924, float %2925, ptr addrspace(3) getelementptr (i8, ptr 
addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %2571, i32 0, i32 0) #3, !dbg !153 + %2927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 0, !dbg !153 + %2928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 1, !dbg !153 + %2929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 2, !dbg !153 + %2930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 3, !dbg !153 + %2931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 4, !dbg !153 + %2932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr 
addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 5, !dbg !153 + %2933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 6, !dbg !153 + %2934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 7, !dbg !153 + %2935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 8, !dbg !153 + %2936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 9, !dbg !153 + %2937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 10, !dbg !153 + %2938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), 
i32, i32 } %2926, 11, !dbg !153 + %2939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 12, !dbg !153 + %2940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 13, !dbg !153 + %2941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 14, !dbg !153 + %2942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 15, !dbg !153 + %2943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 16, !dbg !153 + %2944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 17, !dbg !153 + 
%2945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 18, !dbg !153 + %2946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 19, !dbg !153 + %2947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 20, !dbg !153 + %2948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 21, !dbg !153 + %2949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 22, !dbg !153 + %2950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 23, !dbg !153 + %2951 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 24, !dbg !153 + %2952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 25, !dbg !153 + %2953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 26, !dbg !153 + %2954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 27, !dbg !153 + %2955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 28, !dbg !153 + %2956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 29, !dbg !153 + %2957 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 30, !dbg !153 + %2958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %2926, 31, !dbg !153 + %2959 = fmul float %2927, 0x3FB6A09E60000000, !dbg !156 + %2960 = fmul float %2928, 0x3FB6A09E60000000, !dbg !156 + %2961 = fmul float %2929, 0x3FB6A09E60000000, !dbg !156 + %2962 = fmul float %2930, 0x3FB6A09E60000000, !dbg !156 + %2963 = fmul float %2931, 0x3FB6A09E60000000, !dbg !156 + %2964 = fmul float %2932, 0x3FB6A09E60000000, !dbg !156 + %2965 = fmul float %2933, 0x3FB6A09E60000000, !dbg !156 + %2966 = fmul float %2934, 0x3FB6A09E60000000, !dbg !156 + %2967 = fmul float %2935, 0x3FB6A09E60000000, !dbg !156 + %2968 = fmul float %2936, 0x3FB6A09E60000000, !dbg !156 + %2969 = fmul float %2937, 0x3FB6A09E60000000, !dbg !156 + %2970 = fmul float %2938, 0x3FB6A09E60000000, !dbg !156 + %2971 = fmul float %2939, 0x3FB6A09E60000000, !dbg !156 + %2972 = fmul float %2940, 0x3FB6A09E60000000, !dbg !156 + %2973 = fmul float %2941, 0x3FB6A09E60000000, !dbg !156 + %2974 = fmul float %2942, 0x3FB6A09E60000000, !dbg !156 + %2975 = fmul float %2943, 0x3FB6A09E60000000, !dbg !156 + %2976 = fmul float %2944, 0x3FB6A09E60000000, !dbg !156 + %2977 = fmul float %2945, 0x3FB6A09E60000000, !dbg !156 + %2978 = fmul float %2946, 0x3FB6A09E60000000, !dbg !156 + %2979 = fmul float %2947, 0x3FB6A09E60000000, !dbg !156 + %2980 = fmul float %2948, 0x3FB6A09E60000000, !dbg !156 + %2981 = fmul float %2949, 0x3FB6A09E60000000, !dbg !156 + %2982 = fmul float %2950, 0x3FB6A09E60000000, !dbg !156 + %2983 = fmul float %2951, 
0x3FB6A09E60000000, !dbg !156 + %2984 = fmul float %2952, 0x3FB6A09E60000000, !dbg !156 + %2985 = fmul float %2953, 0x3FB6A09E60000000, !dbg !156 + %2986 = fmul float %2954, 0x3FB6A09E60000000, !dbg !156 + %2987 = fmul float %2955, 0x3FB6A09E60000000, !dbg !156 + %2988 = fmul float %2956, 0x3FB6A09E60000000, !dbg !156 + %2989 = fmul float %2957, 0x3FB6A09E60000000, !dbg !156 + %2990 = fmul float %2958, 0x3FB6A09E60000000, !dbg !156 + %2991 = fmul float %2959, 0x3FF7154760000000, !dbg !157 + %2992 = fmul float %2960, 0x3FF7154760000000, !dbg !157 + %2993 = fmul float %2961, 0x3FF7154760000000, !dbg !157 + %2994 = fmul float %2962, 0x3FF7154760000000, !dbg !157 + %2995 = fmul float %2963, 0x3FF7154760000000, !dbg !157 + %2996 = fmul float %2964, 0x3FF7154760000000, !dbg !157 + %2997 = fmul float %2965, 0x3FF7154760000000, !dbg !157 + %2998 = fmul float %2966, 0x3FF7154760000000, !dbg !157 + %2999 = fmul float %2967, 0x3FF7154760000000, !dbg !157 + %3000 = fmul float %2968, 0x3FF7154760000000, !dbg !157 + %3001 = fmul float %2969, 0x3FF7154760000000, !dbg !157 + %3002 = fmul float %2970, 0x3FF7154760000000, !dbg !157 + %3003 = fmul float %2971, 0x3FF7154760000000, !dbg !157 + %3004 = fmul float %2972, 0x3FF7154760000000, !dbg !157 + %3005 = fmul float %2973, 0x3FF7154760000000, !dbg !157 + %3006 = fmul float %2974, 0x3FF7154760000000, !dbg !157 + %3007 = fmul float %2975, 0x3FF7154760000000, !dbg !157 + %3008 = fmul float %2976, 0x3FF7154760000000, !dbg !157 + %3009 = fmul float %2977, 0x3FF7154760000000, !dbg !157 + %3010 = fmul float %2978, 0x3FF7154760000000, !dbg !157 + %3011 = fmul float %2979, 0x3FF7154760000000, !dbg !157 + %3012 = fmul float %2980, 0x3FF7154760000000, !dbg !157 + %3013 = fmul float %2981, 0x3FF7154760000000, !dbg !157 + %3014 = fmul float %2982, 0x3FF7154760000000, !dbg !157 + %3015 = fmul float %2983, 0x3FF7154760000000, !dbg !157 + %3016 = fmul float %2984, 0x3FF7154760000000, !dbg !157 + %3017 = fmul float %2985, 0x3FF7154760000000, !dbg 
!157 + %3018 = fmul float %2986, 0x3FF7154760000000, !dbg !157 + %3019 = fmul float %2987, 0x3FF7154760000000, !dbg !157 + %3020 = fmul float %2988, 0x3FF7154760000000, !dbg !157 + %3021 = fmul float %2989, 0x3FF7154760000000, !dbg !157 + %3022 = fmul float %2990, 0x3FF7154760000000, !dbg !157 + %3023 = fsub float %2991, %356, !dbg !158 + %3024 = fsub float %2992, %356, !dbg !158 + %3025 = fsub float %2993, %357, !dbg !158 + %3026 = fsub float %2994, %357, !dbg !158 + %3027 = fsub float %2995, %356, !dbg !158 + %3028 = fsub float %2996, %356, !dbg !158 + %3029 = fsub float %2997, %357, !dbg !158 + %3030 = fsub float %2998, %357, !dbg !158 + %3031 = fsub float %2999, %356, !dbg !158 + %3032 = fsub float %3000, %356, !dbg !158 + %3033 = fsub float %3001, %357, !dbg !158 + %3034 = fsub float %3002, %357, !dbg !158 + %3035 = fsub float %3003, %356, !dbg !158 + %3036 = fsub float %3004, %356, !dbg !158 + %3037 = fsub float %3005, %357, !dbg !158 + %3038 = fsub float %3006, %357, !dbg !158 + %3039 = fsub float %3007, %356, !dbg !158 + %3040 = fsub float %3008, %356, !dbg !158 + %3041 = fsub float %3009, %357, !dbg !158 + %3042 = fsub float %3010, %357, !dbg !158 + %3043 = fsub float %3011, %356, !dbg !158 + %3044 = fsub float %3012, %356, !dbg !158 + %3045 = fsub float %3013, %357, !dbg !158 + %3046 = fsub float %3014, %357, !dbg !158 + %3047 = fsub float %3015, %356, !dbg !158 + %3048 = fsub float %3016, %356, !dbg !158 + %3049 = fsub float %3017, %357, !dbg !158 + %3050 = fsub float %3018, %357, !dbg !158 + %3051 = fsub float %3019, %356, !dbg !158 + %3052 = fsub float %3020, %356, !dbg !158 + %3053 = fsub float %3021, %357, !dbg !158 + %3054 = fsub float %3022, %357, !dbg !158 + %3055 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1316 = icmp eq i32 %3055, 0, !dbg !159 + br i1 %.not.i1316, label %3058, label %3056, !dbg !159 + +3056: ; preds = %2561 + %3057 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3023) #3, !dbg !159 + br label 
%__nv_exp2f.exit1318, !dbg !159 + +3058: ; preds = %2561 + %3059 = tail call float @llvm.nvvm.ex2.approx.f(float %3023) #3, !dbg !159 + br label %__nv_exp2f.exit1318, !dbg !159 + +__nv_exp2f.exit1318: ; preds = %3056, %3058 + %.0.i1317 = phi float [ %3057, %3056 ], [ %3059, %3058 ], !dbg !159 + %3060 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1319 = icmp eq i32 %3060, 0, !dbg !159 + br i1 %.not.i1319, label %3063, label %3061, !dbg !159 + +3061: ; preds = %__nv_exp2f.exit1318 + %3062 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3024) #3, !dbg !159 + br label %__nv_exp2f.exit1321, !dbg !159 + +3063: ; preds = %__nv_exp2f.exit1318 + %3064 = tail call float @llvm.nvvm.ex2.approx.f(float %3024) #3, !dbg !159 + br label %__nv_exp2f.exit1321, !dbg !159 + +__nv_exp2f.exit1321: ; preds = %3061, %3063 + %.0.i1320 = phi float [ %3062, %3061 ], [ %3064, %3063 ], !dbg !159 + %3065 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1322 = icmp eq i32 %3065, 0, !dbg !159 + br i1 %.not.i1322, label %3068, label %3066, !dbg !159 + +3066: ; preds = %__nv_exp2f.exit1321 + %3067 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3025) #3, !dbg !159 + br label %__nv_exp2f.exit1324, !dbg !159 + +3068: ; preds = %__nv_exp2f.exit1321 + %3069 = tail call float @llvm.nvvm.ex2.approx.f(float %3025) #3, !dbg !159 + br label %__nv_exp2f.exit1324, !dbg !159 + +__nv_exp2f.exit1324: ; preds = %3066, %3068 + %.0.i1323 = phi float [ %3067, %3066 ], [ %3069, %3068 ], !dbg !159 + %3070 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1325 = icmp eq i32 %3070, 0, !dbg !159 + br i1 %.not.i1325, label %3073, label %3071, !dbg !159 + +3071: ; preds = %__nv_exp2f.exit1324 + %3072 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3026) #3, !dbg !159 + br label %__nv_exp2f.exit1327, !dbg !159 + +3073: ; preds = %__nv_exp2f.exit1324 + %3074 = tail call float @llvm.nvvm.ex2.approx.f(float %3026) #3, !dbg !159 + br 
label %__nv_exp2f.exit1327, !dbg !159 + +__nv_exp2f.exit1327: ; preds = %3071, %3073 + %.0.i1326 = phi float [ %3072, %3071 ], [ %3074, %3073 ], !dbg !159 + %3075 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1328 = icmp eq i32 %3075, 0, !dbg !159 + br i1 %.not.i1328, label %3078, label %3076, !dbg !159 + +3076: ; preds = %__nv_exp2f.exit1327 + %3077 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3027) #3, !dbg !159 + br label %__nv_exp2f.exit1330, !dbg !159 + +3078: ; preds = %__nv_exp2f.exit1327 + %3079 = tail call float @llvm.nvvm.ex2.approx.f(float %3027) #3, !dbg !159 + br label %__nv_exp2f.exit1330, !dbg !159 + +__nv_exp2f.exit1330: ; preds = %3076, %3078 + %.0.i1329 = phi float [ %3077, %3076 ], [ %3079, %3078 ], !dbg !159 + %3080 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1331 = icmp eq i32 %3080, 0, !dbg !159 + br i1 %.not.i1331, label %3083, label %3081, !dbg !159 + +3081: ; preds = %__nv_exp2f.exit1330 + %3082 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3028) #3, !dbg !159 + br label %__nv_exp2f.exit1333, !dbg !159 + +3083: ; preds = %__nv_exp2f.exit1330 + %3084 = tail call float @llvm.nvvm.ex2.approx.f(float %3028) #3, !dbg !159 + br label %__nv_exp2f.exit1333, !dbg !159 + +__nv_exp2f.exit1333: ; preds = %3081, %3083 + %.0.i1332 = phi float [ %3082, %3081 ], [ %3084, %3083 ], !dbg !159 + %3085 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1334 = icmp eq i32 %3085, 0, !dbg !159 + br i1 %.not.i1334, label %3088, label %3086, !dbg !159 + +3086: ; preds = %__nv_exp2f.exit1333 + %3087 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3029) #3, !dbg !159 + br label %__nv_exp2f.exit1336, !dbg !159 + +3088: ; preds = %__nv_exp2f.exit1333 + %3089 = tail call float @llvm.nvvm.ex2.approx.f(float %3029) #3, !dbg !159 + br label %__nv_exp2f.exit1336, !dbg !159 + +__nv_exp2f.exit1336: ; preds = %3086, %3088 + %.0.i1335 = phi float [ %3087, %3086 ], [ %3089, %3088 
], !dbg !159 + %3090 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1337 = icmp eq i32 %3090, 0, !dbg !159 + br i1 %.not.i1337, label %3093, label %3091, !dbg !159 + +3091: ; preds = %__nv_exp2f.exit1336 + %3092 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3030) #3, !dbg !159 + br label %__nv_exp2f.exit1339, !dbg !159 + +3093: ; preds = %__nv_exp2f.exit1336 + %3094 = tail call float @llvm.nvvm.ex2.approx.f(float %3030) #3, !dbg !159 + br label %__nv_exp2f.exit1339, !dbg !159 + +__nv_exp2f.exit1339: ; preds = %3091, %3093 + %.0.i1338 = phi float [ %3092, %3091 ], [ %3094, %3093 ], !dbg !159 + %3095 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1340 = icmp eq i32 %3095, 0, !dbg !159 + br i1 %.not.i1340, label %3098, label %3096, !dbg !159 + +3096: ; preds = %__nv_exp2f.exit1339 + %3097 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3031) #3, !dbg !159 + br label %__nv_exp2f.exit1342, !dbg !159 + +3098: ; preds = %__nv_exp2f.exit1339 + %3099 = tail call float @llvm.nvvm.ex2.approx.f(float %3031) #3, !dbg !159 + br label %__nv_exp2f.exit1342, !dbg !159 + +__nv_exp2f.exit1342: ; preds = %3096, %3098 + %.0.i1341 = phi float [ %3097, %3096 ], [ %3099, %3098 ], !dbg !159 + %3100 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1343 = icmp eq i32 %3100, 0, !dbg !159 + br i1 %.not.i1343, label %3103, label %3101, !dbg !159 + +3101: ; preds = %__nv_exp2f.exit1342 + %3102 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3032) #3, !dbg !159 + br label %__nv_exp2f.exit1345, !dbg !159 + +3103: ; preds = %__nv_exp2f.exit1342 + %3104 = tail call float @llvm.nvvm.ex2.approx.f(float %3032) #3, !dbg !159 + br label %__nv_exp2f.exit1345, !dbg !159 + +__nv_exp2f.exit1345: ; preds = %3101, %3103 + %.0.i1344 = phi float [ %3102, %3101 ], [ %3104, %3103 ], !dbg !159 + %3105 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1346 = icmp eq i32 %3105, 0, !dbg !159 + br i1 
%.not.i1346, label %3108, label %3106, !dbg !159 + +3106: ; preds = %__nv_exp2f.exit1345 + %3107 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3033) #3, !dbg !159 + br label %__nv_exp2f.exit1348, !dbg !159 + +3108: ; preds = %__nv_exp2f.exit1345 + %3109 = tail call float @llvm.nvvm.ex2.approx.f(float %3033) #3, !dbg !159 + br label %__nv_exp2f.exit1348, !dbg !159 + +__nv_exp2f.exit1348: ; preds = %3106, %3108 + %.0.i1347 = phi float [ %3107, %3106 ], [ %3109, %3108 ], !dbg !159 + %3110 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1349 = icmp eq i32 %3110, 0, !dbg !159 + br i1 %.not.i1349, label %3113, label %3111, !dbg !159 + +3111: ; preds = %__nv_exp2f.exit1348 + %3112 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3034) #3, !dbg !159 + br label %__nv_exp2f.exit1351, !dbg !159 + +3113: ; preds = %__nv_exp2f.exit1348 + %3114 = tail call float @llvm.nvvm.ex2.approx.f(float %3034) #3, !dbg !159 + br label %__nv_exp2f.exit1351, !dbg !159 + +__nv_exp2f.exit1351: ; preds = %3111, %3113 + %.0.i1350 = phi float [ %3112, %3111 ], [ %3114, %3113 ], !dbg !159 + %3115 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1352 = icmp eq i32 %3115, 0, !dbg !159 + br i1 %.not.i1352, label %3118, label %3116, !dbg !159 + +3116: ; preds = %__nv_exp2f.exit1351 + %3117 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3035) #3, !dbg !159 + br label %__nv_exp2f.exit1354, !dbg !159 + +3118: ; preds = %__nv_exp2f.exit1351 + %3119 = tail call float @llvm.nvvm.ex2.approx.f(float %3035) #3, !dbg !159 + br label %__nv_exp2f.exit1354, !dbg !159 + +__nv_exp2f.exit1354: ; preds = %3116, %3118 + %.0.i1353 = phi float [ %3117, %3116 ], [ %3119, %3118 ], !dbg !159 + %3120 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1355 = icmp eq i32 %3120, 0, !dbg !159 + br i1 %.not.i1355, label %3123, label %3121, !dbg !159 + +3121: ; preds = %__nv_exp2f.exit1354 + %3122 = tail call float 
@llvm.nvvm.ex2.approx.ftz.f(float %3036) #3, !dbg !159 + br label %__nv_exp2f.exit1357, !dbg !159 + +3123: ; preds = %__nv_exp2f.exit1354 + %3124 = tail call float @llvm.nvvm.ex2.approx.f(float %3036) #3, !dbg !159 + br label %__nv_exp2f.exit1357, !dbg !159 + +__nv_exp2f.exit1357: ; preds = %3121, %3123 + %.0.i1356 = phi float [ %3122, %3121 ], [ %3124, %3123 ], !dbg !159 + %3125 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1358 = icmp eq i32 %3125, 0, !dbg !159 + br i1 %.not.i1358, label %3128, label %3126, !dbg !159 + +3126: ; preds = %__nv_exp2f.exit1357 + %3127 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3037) #3, !dbg !159 + br label %__nv_exp2f.exit1360, !dbg !159 + +3128: ; preds = %__nv_exp2f.exit1357 + %3129 = tail call float @llvm.nvvm.ex2.approx.f(float %3037) #3, !dbg !159 + br label %__nv_exp2f.exit1360, !dbg !159 + +__nv_exp2f.exit1360: ; preds = %3126, %3128 + %.0.i1359 = phi float [ %3127, %3126 ], [ %3129, %3128 ], !dbg !159 + %3130 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1361 = icmp eq i32 %3130, 0, !dbg !159 + br i1 %.not.i1361, label %3133, label %3131, !dbg !159 + +3131: ; preds = %__nv_exp2f.exit1360 + %3132 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3038) #3, !dbg !159 + br label %__nv_exp2f.exit1363, !dbg !159 + +3133: ; preds = %__nv_exp2f.exit1360 + %3134 = tail call float @llvm.nvvm.ex2.approx.f(float %3038) #3, !dbg !159 + br label %__nv_exp2f.exit1363, !dbg !159 + +__nv_exp2f.exit1363: ; preds = %3131, %3133 + %.0.i1362 = phi float [ %3132, %3131 ], [ %3134, %3133 ], !dbg !159 + %3135 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1364 = icmp eq i32 %3135, 0, !dbg !159 + br i1 %.not.i1364, label %3138, label %3136, !dbg !159 + +3136: ; preds = %__nv_exp2f.exit1363 + %3137 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3039) #3, !dbg !159 + br label %__nv_exp2f.exit1366, !dbg !159 + +3138: ; preds = %__nv_exp2f.exit1363 + 
%3139 = tail call float @llvm.nvvm.ex2.approx.f(float %3039) #3, !dbg !159 + br label %__nv_exp2f.exit1366, !dbg !159 + +__nv_exp2f.exit1366: ; preds = %3136, %3138 + %.0.i1365 = phi float [ %3137, %3136 ], [ %3139, %3138 ], !dbg !159 + %3140 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1367 = icmp eq i32 %3140, 0, !dbg !159 + br i1 %.not.i1367, label %3143, label %3141, !dbg !159 + +3141: ; preds = %__nv_exp2f.exit1366 + %3142 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3040) #3, !dbg !159 + br label %__nv_exp2f.exit1369, !dbg !159 + +3143: ; preds = %__nv_exp2f.exit1366 + %3144 = tail call float @llvm.nvvm.ex2.approx.f(float %3040) #3, !dbg !159 + br label %__nv_exp2f.exit1369, !dbg !159 + +__nv_exp2f.exit1369: ; preds = %3141, %3143 + %.0.i1368 = phi float [ %3142, %3141 ], [ %3144, %3143 ], !dbg !159 + %3145 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1370 = icmp eq i32 %3145, 0, !dbg !159 + br i1 %.not.i1370, label %3148, label %3146, !dbg !159 + +3146: ; preds = %__nv_exp2f.exit1369 + %3147 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3041) #3, !dbg !159 + br label %__nv_exp2f.exit1372, !dbg !159 + +3148: ; preds = %__nv_exp2f.exit1369 + %3149 = tail call float @llvm.nvvm.ex2.approx.f(float %3041) #3, !dbg !159 + br label %__nv_exp2f.exit1372, !dbg !159 + +__nv_exp2f.exit1372: ; preds = %3146, %3148 + %.0.i1371 = phi float [ %3147, %3146 ], [ %3149, %3148 ], !dbg !159 + %3150 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1373 = icmp eq i32 %3150, 0, !dbg !159 + br i1 %.not.i1373, label %3153, label %3151, !dbg !159 + +3151: ; preds = %__nv_exp2f.exit1372 + %3152 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3042) #3, !dbg !159 + br label %__nv_exp2f.exit1375, !dbg !159 + +3153: ; preds = %__nv_exp2f.exit1372 + %3154 = tail call float @llvm.nvvm.ex2.approx.f(float %3042) #3, !dbg !159 + br label %__nv_exp2f.exit1375, !dbg !159 + +__nv_exp2f.exit1375: 
; preds = %3151, %3153 + %.0.i1374 = phi float [ %3152, %3151 ], [ %3154, %3153 ], !dbg !159 + %3155 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1376 = icmp eq i32 %3155, 0, !dbg !159 + br i1 %.not.i1376, label %3158, label %3156, !dbg !159 + +3156: ; preds = %__nv_exp2f.exit1375 + %3157 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3043) #3, !dbg !159 + br label %__nv_exp2f.exit1378, !dbg !159 + +3158: ; preds = %__nv_exp2f.exit1375 + %3159 = tail call float @llvm.nvvm.ex2.approx.f(float %3043) #3, !dbg !159 + br label %__nv_exp2f.exit1378, !dbg !159 + +__nv_exp2f.exit1378: ; preds = %3156, %3158 + %.0.i1377 = phi float [ %3157, %3156 ], [ %3159, %3158 ], !dbg !159 + %3160 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1379 = icmp eq i32 %3160, 0, !dbg !159 + br i1 %.not.i1379, label %3163, label %3161, !dbg !159 + +3161: ; preds = %__nv_exp2f.exit1378 + %3162 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3044) #3, !dbg !159 + br label %__nv_exp2f.exit1381, !dbg !159 + +3163: ; preds = %__nv_exp2f.exit1378 + %3164 = tail call float @llvm.nvvm.ex2.approx.f(float %3044) #3, !dbg !159 + br label %__nv_exp2f.exit1381, !dbg !159 + +__nv_exp2f.exit1381: ; preds = %3161, %3163 + %.0.i1380 = phi float [ %3162, %3161 ], [ %3164, %3163 ], !dbg !159 + %3165 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1382 = icmp eq i32 %3165, 0, !dbg !159 + br i1 %.not.i1382, label %3168, label %3166, !dbg !159 + +3166: ; preds = %__nv_exp2f.exit1381 + %3167 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3045) #3, !dbg !159 + br label %__nv_exp2f.exit1384, !dbg !159 + +3168: ; preds = %__nv_exp2f.exit1381 + %3169 = tail call float @llvm.nvvm.ex2.approx.f(float %3045) #3, !dbg !159 + br label %__nv_exp2f.exit1384, !dbg !159 + +__nv_exp2f.exit1384: ; preds = %3166, %3168 + %.0.i1383 = phi float [ %3167, %3166 ], [ %3169, %3168 ], !dbg !159 + %3170 = tail call i32 @__nvvm_reflect(ptr 
nonnull @.str) #3, !dbg !159 + %.not.i1385 = icmp eq i32 %3170, 0, !dbg !159 + br i1 %.not.i1385, label %3173, label %3171, !dbg !159 + +3171: ; preds = %__nv_exp2f.exit1384 + %3172 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3046) #3, !dbg !159 + br label %__nv_exp2f.exit1387, !dbg !159 + +3173: ; preds = %__nv_exp2f.exit1384 + %3174 = tail call float @llvm.nvvm.ex2.approx.f(float %3046) #3, !dbg !159 + br label %__nv_exp2f.exit1387, !dbg !159 + +__nv_exp2f.exit1387: ; preds = %3171, %3173 + %.0.i1386 = phi float [ %3172, %3171 ], [ %3174, %3173 ], !dbg !159 + %3175 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1388 = icmp eq i32 %3175, 0, !dbg !159 + br i1 %.not.i1388, label %3178, label %3176, !dbg !159 + +3176: ; preds = %__nv_exp2f.exit1387 + %3177 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3047) #3, !dbg !159 + br label %__nv_exp2f.exit1390, !dbg !159 + +3178: ; preds = %__nv_exp2f.exit1387 + %3179 = tail call float @llvm.nvvm.ex2.approx.f(float %3047) #3, !dbg !159 + br label %__nv_exp2f.exit1390, !dbg !159 + +__nv_exp2f.exit1390: ; preds = %3176, %3178 + %.0.i1389 = phi float [ %3177, %3176 ], [ %3179, %3178 ], !dbg !159 + %3180 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1391 = icmp eq i32 %3180, 0, !dbg !159 + br i1 %.not.i1391, label %3183, label %3181, !dbg !159 + +3181: ; preds = %__nv_exp2f.exit1390 + %3182 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3048) #3, !dbg !159 + br label %__nv_exp2f.exit1393, !dbg !159 + +3183: ; preds = %__nv_exp2f.exit1390 + %3184 = tail call float @llvm.nvvm.ex2.approx.f(float %3048) #3, !dbg !159 + br label %__nv_exp2f.exit1393, !dbg !159 + +__nv_exp2f.exit1393: ; preds = %3181, %3183 + %.0.i1392 = phi float [ %3182, %3181 ], [ %3184, %3183 ], !dbg !159 + %3185 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1394 = icmp eq i32 %3185, 0, !dbg !159 + br i1 %.not.i1394, label %3188, label %3186, !dbg !159 + +3186: 
; preds = %__nv_exp2f.exit1393 + %3187 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3049) #3, !dbg !159 + br label %__nv_exp2f.exit1396, !dbg !159 + +3188: ; preds = %__nv_exp2f.exit1393 + %3189 = tail call float @llvm.nvvm.ex2.approx.f(float %3049) #3, !dbg !159 + br label %__nv_exp2f.exit1396, !dbg !159 + +__nv_exp2f.exit1396: ; preds = %3186, %3188 + %.0.i1395 = phi float [ %3187, %3186 ], [ %3189, %3188 ], !dbg !159 + %3190 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1397 = icmp eq i32 %3190, 0, !dbg !159 + br i1 %.not.i1397, label %3193, label %3191, !dbg !159 + +3191: ; preds = %__nv_exp2f.exit1396 + %3192 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3050) #3, !dbg !159 + br label %__nv_exp2f.exit1399, !dbg !159 + +3193: ; preds = %__nv_exp2f.exit1396 + %3194 = tail call float @llvm.nvvm.ex2.approx.f(float %3050) #3, !dbg !159 + br label %__nv_exp2f.exit1399, !dbg !159 + +__nv_exp2f.exit1399: ; preds = %3191, %3193 + %.0.i1398 = phi float [ %3192, %3191 ], [ %3194, %3193 ], !dbg !159 + %3195 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1400 = icmp eq i32 %3195, 0, !dbg !159 + br i1 %.not.i1400, label %3198, label %3196, !dbg !159 + +3196: ; preds = %__nv_exp2f.exit1399 + %3197 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3051) #3, !dbg !159 + br label %__nv_exp2f.exit1402, !dbg !159 + +3198: ; preds = %__nv_exp2f.exit1399 + %3199 = tail call float @llvm.nvvm.ex2.approx.f(float %3051) #3, !dbg !159 + br label %__nv_exp2f.exit1402, !dbg !159 + +__nv_exp2f.exit1402: ; preds = %3196, %3198 + %.0.i1401 = phi float [ %3197, %3196 ], [ %3199, %3198 ], !dbg !159 + %3200 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1403 = icmp eq i32 %3200, 0, !dbg !159 + br i1 %.not.i1403, label %3203, label %3201, !dbg !159 + +3201: ; preds = %__nv_exp2f.exit1402 + %3202 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3052) #3, !dbg !159 + br label 
%__nv_exp2f.exit1405, !dbg !159 + +3203: ; preds = %__nv_exp2f.exit1402 + %3204 = tail call float @llvm.nvvm.ex2.approx.f(float %3052) #3, !dbg !159 + br label %__nv_exp2f.exit1405, !dbg !159 + +__nv_exp2f.exit1405: ; preds = %3201, %3203 + %.0.i1404 = phi float [ %3202, %3201 ], [ %3204, %3203 ], !dbg !159 + %3205 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1406 = icmp eq i32 %3205, 0, !dbg !159 + br i1 %.not.i1406, label %3208, label %3206, !dbg !159 + +3206: ; preds = %__nv_exp2f.exit1405 + %3207 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3053) #3, !dbg !159 + br label %__nv_exp2f.exit1408, !dbg !159 + +3208: ; preds = %__nv_exp2f.exit1405 + %3209 = tail call float @llvm.nvvm.ex2.approx.f(float %3053) #3, !dbg !159 + br label %__nv_exp2f.exit1408, !dbg !159 + +__nv_exp2f.exit1408: ; preds = %3206, %3208 + %.0.i1407 = phi float [ %3207, %3206 ], [ %3209, %3208 ], !dbg !159 + %3210 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !159 + %.not.i1409 = icmp eq i32 %3210, 0, !dbg !159 + br i1 %.not.i1409, label %3213, label %3211, !dbg !159 + +3211: ; preds = %__nv_exp2f.exit1408 + %3212 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3054) #3, !dbg !159 + br label %__nv_exp2f.exit1411, !dbg !159 + +3213: ; preds = %__nv_exp2f.exit1408 + %3214 = tail call float @llvm.nvvm.ex2.approx.f(float %3054) #3, !dbg !159 + br label %__nv_exp2f.exit1411, !dbg !159 + +__nv_exp2f.exit1411: ; preds = %3211, %3213 + %.0.i1410 = phi float [ %3212, %3211 ], [ %3214, %3213 ], !dbg !159 + %3215 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2570, !dbg !150 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !160 + %3216 = add i32 %2574, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3217 = lshr exact i32 %3216, 4, !dbg !160 + %3218 = and i32 %3217, 16383, !dbg !160 + %3219 = zext nneg i32 %3218 to 
i64, !dbg !160 + %3220 = or disjoint i64 %3219, 4611686293372403712, !dbg !160 + %3221 = ptrtoint ptr addrspace(3) %3215 to i32, !dbg !160 + %3222 = lshr exact i32 %3221, 4, !dbg !160 + %3223 = and i32 %3222, 16383, !dbg !160 + %3224 = zext nneg i32 %3223 to i64, !dbg !160 + %3225 = or disjoint i64 %3224, 4611686293338849280, !dbg !160 + %3226 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %3220, i64 %3225) #3, !dbg !160 + %3227 = add i32 %2586, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3228 = lshr exact i32 %3227, 4, !dbg !160 + %3229 = and i32 %3228, 16383, !dbg !160 + %3230 = zext nneg i32 %3229 to i64, !dbg !160 + %3231 = or disjoint i64 %3230, 4611686293372403712, !dbg !160 + %3232 = add i32 %3221, 32, !dbg !160 + %3233 = lshr exact i32 %3232, 4, !dbg !160 + %3234 = and i32 %3233, 16383, !dbg !160 + %3235 = zext nneg i32 %3234 to i64, !dbg !160 + %3236 = or disjoint i64 %3235, 4611686293338849280, !dbg !160 + %3237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 0, !dbg !160 + %3238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %3226, 1, !dbg !160 + %3239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 2, !dbg !160 + %3240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 3, !dbg !160 + %3241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 4, !dbg !160 + %3242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 5, !dbg !160 + %3243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 6, !dbg !160 + %3244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 7, !dbg !160 + %3245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 8, !dbg !160 + %3246 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 9, !dbg !160 + %3247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 10, !dbg !160 + %3248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 11, !dbg !160 + %3249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 12, !dbg !160 + %3250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 13, !dbg !160 + %3251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 14, !dbg !160 + %3252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 15, !dbg !160 + %3253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 16, !dbg !160 + %3254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 17, !dbg !160 + %3255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 18, !dbg !160 + %3256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 19, !dbg !160 + %3257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 20, !dbg !160 + %3258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 21, !dbg !160 + %3259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 22, !dbg !160 + %3260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%3226, 23, !dbg !160 + %3261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 24, !dbg !160 + %3262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 25, !dbg !160 + %3263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 26, !dbg !160 + %3264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 27, !dbg !160 + %3265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 28, !dbg !160 + %3266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 29, !dbg !160 + %3267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 30, !dbg !160 + %3268 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3226, 31, !dbg !160 + %3269 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3237, float %3238, float %3239, float %3240, float %3241, float %3242, float %3243, float %3244, float %3245, float %3246, float %3247, float %3248, float %3249, float %3250, float %3251, float %3252, float %3253, float %3254, float %3255, float %3256, float %3257, float %3258, float %3259, float %3260, float %3261, float %3262, float %3263, float %3264, float %3265, float %3266, float %3267, float %3268, i64 %3231, i64 %3236, i1 true) #3, !dbg !160 + %3270 = add i32 %2630, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3271 = lshr exact i32 %3270, 4, !dbg !160 + %3272 = and i32 %3271, 16383, !dbg !160 + %3273 = zext nneg i32 %3272 to i64, !dbg !160 + %3274 = or disjoint i64 %3273, 4611686293372403712, !dbg !160 + %3275 = add i32 %3221, 64, !dbg !160 + %3276 = lshr exact i32 %3275, 4, !dbg !160 + %3277 = and i32 %3276, 16383, !dbg !160 + %3278 = zext nneg i32 %3277 to i64, !dbg !160 + %3279 = or disjoint i64 %3278, 4611686293338849280, !dbg !160 + %3280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 0, !dbg !160 + %3281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 1, !dbg !160 + %3282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 2, !dbg !160 + %3283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 3, !dbg !160 + %3284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 4, !dbg !160 + %3285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 5, !dbg !160 + %3286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 6, !dbg !160 + %3287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %3269, 7, !dbg !160 + %3288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 8, !dbg !160 + %3289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 9, !dbg !160 + %3290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 10, !dbg !160 + %3291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 11, !dbg !160 + %3292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 12, !dbg !160 + %3293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 13, !dbg !160 + %3294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 14, !dbg !160 + %3295 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 15, !dbg !160 + %3296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 16, !dbg !160 + %3297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 17, !dbg !160 + %3298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 18, !dbg !160 + %3299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 19, !dbg !160 + %3300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 20, !dbg !160 + %3301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 21, !dbg !160 + %3302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 22, !dbg !160 + %3303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 23, !dbg !160 + %3304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 24, !dbg !160 + %3305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 25, !dbg !160 + %3306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 26, !dbg !160 + %3307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 27, !dbg !160 + %3308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 28, !dbg !160 + %3309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 29, 
!dbg !160 + %3310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 30, !dbg !160 + %3311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3269, 31, !dbg !160 + %3312 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3280, float %3281, float %3282, float %3283, float %3284, float %3285, float %3286, float %3287, float %3288, float %3289, float %3290, float %3291, float %3292, float %3293, float %3294, float %3295, float %3296, float %3297, float %3298, float %3299, float %3300, float %3301, float %3302, float %3303, float %3304, float %3305, float %3306, float %3307, float %3308, float %3309, float %3310, float %3311, i64 %3274, i64 %3279, i1 true) #3, !dbg !160 + %3313 = add i32 %2674, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3314 = lshr exact i32 %3313, 4, !dbg !160 + %3315 = and i32 %3314, 16383, !dbg !160 + %3316 = zext nneg i32 %3315 to i64, !dbg !160 + %3317 = or disjoint i64 %3316, 4611686293372403712, !dbg 
!160 + %3318 = add i32 %3221, 96, !dbg !160 + %3319 = lshr exact i32 %3318, 4, !dbg !160 + %3320 = and i32 %3319, 16383, !dbg !160 + %3321 = zext nneg i32 %3320 to i64, !dbg !160 + %3322 = or disjoint i64 %3321, 4611686293338849280, !dbg !160 + %3323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 0, !dbg !160 + %3324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 1, !dbg !160 + %3325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 2, !dbg !160 + %3326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 3, !dbg !160 + %3327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 4, !dbg !160 + %3328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 5, !dbg !160 + %3329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 6, !dbg !160 + %3330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 7, !dbg !160 + %3331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 8, !dbg !160 + %3332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 9, !dbg !160 + %3333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 10, !dbg !160 + %3334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 11, !dbg !160 + %3335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 12, !dbg !160 + %3336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %3312, 13, !dbg !160 + %3337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 14, !dbg !160 + %3338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 15, !dbg !160 + %3339 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 16, !dbg !160 + %3340 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 17, !dbg !160 + %3341 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 18, !dbg !160 + %3342 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 19, !dbg !160 + %3343 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 20, !dbg !160 + %3344 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 21, !dbg !160 + %3345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 22, !dbg !160 + %3346 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 23, !dbg !160 + %3347 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 24, !dbg !160 + %3348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 25, !dbg !160 + %3349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 26, !dbg !160 + %3350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 27, !dbg !160 + %3351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %3312, 28, !dbg !160 + %3352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 29, !dbg !160 + %3353 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 30, !dbg !160 + %3354 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3312, 31, !dbg !160 + %3355 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3323, float %3324, float %3325, float %3326, float %3327, float %3328, float %3329, float %3330, float %3331, float %3332, float %3333, float %3334, float %3335, float %3336, float %3337, float %3338, float %3339, float %3340, float %3341, float %3342, float %3343, float %3344, float %3345, float %3346, float %3347, float %3348, float %3349, float %3350, float %3351, float %3352, float %3353, float %3354, i64 %3317, i64 
%3322, i1 true) #3, !dbg !160 + %3356 = add i32 %2718, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3357 = lshr exact i32 %3356, 4, !dbg !160 + %3358 = and i32 %3357, 16383, !dbg !160 + %3359 = zext nneg i32 %3358 to i64, !dbg !160 + %3360 = or disjoint i64 %3359, 4611686293372403712, !dbg !160 + %3361 = add i32 %3221, 8192, !dbg !160 + %3362 = lshr exact i32 %3361, 4, !dbg !160 + %3363 = and i32 %3362, 16383, !dbg !160 + %3364 = zext nneg i32 %3363 to i64, !dbg !160 + %3365 = or disjoint i64 %3364, 4611686293338849280, !dbg !160 + %3366 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 0, !dbg !160 + %3367 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 1, !dbg !160 + %3368 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 2, !dbg !160 + %3369 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 3, !dbg !160 + %3370 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 4, !dbg !160 + %3371 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 5, !dbg !160 + %3372 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 6, !dbg !160 + %3373 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 7, !dbg !160 + %3374 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 8, !dbg !160 + %3375 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 9, !dbg !160 + %3376 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 10, !dbg !160 + %3377 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 11, !dbg !160 + %3378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 12, !dbg !160 + %3379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 13, !dbg !160 + %3380 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 14, !dbg !160 + %3381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 15, !dbg !160 + %3382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 16, !dbg !160 + %3383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 17, !dbg !160 + %3384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 18, !dbg !160 + %3385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%3355, 19, !dbg !160 + %3386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 20, !dbg !160 + %3387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 21, !dbg !160 + %3388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 22, !dbg !160 + %3389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 23, !dbg !160 + %3390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 24, !dbg !160 + %3391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 25, !dbg !160 + %3392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 26, !dbg !160 + %3393 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 27, !dbg !160 + %3394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 28, !dbg !160 + %3395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 29, !dbg !160 + %3396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 30, !dbg !160 + %3397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3355, 31, !dbg !160 + %3398 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3366, float %3367, float %3368, float %3369, float %3370, 
float %3371, float %3372, float %3373, float %3374, float %3375, float %3376, float %3377, float %3378, float %3379, float %3380, float %3381, float %3382, float %3383, float %3384, float %3385, float %3386, float %3387, float %3388, float %3389, float %3390, float %3391, float %3392, float %3393, float %3394, float %3395, float %3396, float %3397, i64 %3360, i64 %3365, i1 true) #3, !dbg !160 + %3399 = add i32 %2762, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3400 = lshr exact i32 %3399, 4, !dbg !160 + %3401 = and i32 %3400, 16383, !dbg !160 + %3402 = zext nneg i32 %3401 to i64, !dbg !160 + %3403 = or disjoint i64 %3402, 4611686293372403712, !dbg !160 + %3404 = add i32 %3221, 8224, !dbg !160 + %3405 = lshr exact i32 %3404, 4, !dbg !160 + %3406 = and i32 %3405, 16383, !dbg !160 + %3407 = zext nneg i32 %3406 to i64, !dbg !160 + %3408 = or disjoint i64 %3407, 4611686293338849280, !dbg !160 + %3409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 0, !dbg !160 + %3410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 1, !dbg !160 + %3411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 2, !dbg !160 + %3412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %3398, 3, !dbg !160 + %3413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 4, !dbg !160 + %3414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 5, !dbg !160 + %3415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 6, !dbg !160 + %3416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 7, !dbg !160 + %3417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 8, !dbg !160 + %3418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 9, !dbg !160 + %3419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 10, !dbg !160 + %3420 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 11, !dbg !160 + %3421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 12, !dbg !160 + %3422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 13, !dbg !160 + %3423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 14, !dbg !160 + %3424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 15, !dbg !160 + %3425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 16, !dbg !160 + %3426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 17, !dbg !160 + %3427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 18, !dbg !160 + %3428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 19, !dbg !160 + %3429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 20, !dbg !160 + %3430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 21, !dbg !160 + %3431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 22, !dbg !160 + %3432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 23, !dbg !160 + %3433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 24, !dbg !160 + %3434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 25, 
!dbg !160 + %3435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 26, !dbg !160 + %3436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 27, !dbg !160 + %3437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 28, !dbg !160 + %3438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 29, !dbg !160 + %3439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 30, !dbg !160 + %3440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3398, 31, !dbg !160 + %3441 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3409, float %3410, float %3411, float %3412, float %3413, float %3414, float %3415, float %3416, float %3417, float %3418, float %3419, float %3420, float %3421, float %3422, float %3423, float %3424, float %3425, float %3426, float %3427, float %3428, float %3429, float %3430, float %3431, float %3432, float %3433, float %3434, float %3435, float %3436, float %3437, float %3438, float %3439, float %3440, i64 %3403, i64 %3408, i1 true) #3, !dbg !160 + %3442 = add i32 %2806, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3443 = lshr exact i32 %3442, 4, !dbg !160 + %3444 = and i32 %3443, 16383, !dbg !160 + %3445 = zext nneg i32 %3444 to i64, !dbg !160 + %3446 = or disjoint i64 %3445, 4611686293372403712, !dbg !160 + %3447 = add i32 %3221, 8256, !dbg !160 + %3448 = lshr exact i32 %3447, 4, !dbg !160 + %3449 = and i32 %3448, 16383, !dbg !160 + %3450 = zext nneg i32 %3449 to i64, !dbg !160 + %3451 = or disjoint i64 %3450, 4611686293338849280, !dbg !160 + %3452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 0, !dbg !160 + %3453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 1, !dbg !160 + %3454 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 2, !dbg !160 + %3455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 3, !dbg !160 + %3456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 4, !dbg !160 + %3457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 5, !dbg !160 + %3458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 6, !dbg !160 + %3459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 7, !dbg !160 + %3460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 8, !dbg !160 + %3461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %3441, 9, !dbg !160 + %3462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 10, !dbg !160 + %3463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 11, !dbg !160 + %3464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 12, !dbg !160 + %3465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 13, !dbg !160 + %3466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 14, !dbg !160 + %3467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 15, !dbg !160 + %3468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 16, !dbg !160 + %3469 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 17, !dbg !160 + %3470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 18, !dbg !160 + %3471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 19, !dbg !160 + %3472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 20, !dbg !160 + %3473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 21, !dbg !160 + %3474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 22, !dbg !160 + %3475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 23, !dbg !160 + %3476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 24, !dbg !160 + %3477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 25, !dbg !160 + %3478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 26, !dbg !160 + %3479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 27, !dbg !160 + %3480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 28, !dbg !160 + %3481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 29, !dbg !160 + %3482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3441, 30, !dbg !160 + %3483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %3441, 31, !dbg !160 + %3484 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3452, float %3453, float %3454, float %3455, float %3456, float %3457, float %3458, float %3459, float %3460, float %3461, float %3462, float %3463, float %3464, float %3465, float %3466, float %3467, float %3468, float %3469, float %3470, float %3471, float %3472, float %3473, float %3474, float %3475, float %3476, float %3477, float %3478, float %3479, float %3480, float %3481, float %3482, float %3483, i64 %3446, i64 %3451, i1 true) #3, !dbg !160 + %3485 = add i32 %2850, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !160 + %3486 = lshr exact i32 %3485, 4, !dbg !160 + %3487 = and i32 %3486, 16383, !dbg !160 + %3488 = zext nneg i32 %3487 to i64, !dbg !160 + %3489 = or disjoint i64 %3488, 4611686293372403712, !dbg !160 + %3490 = add i32 %3221, 8288, !dbg !160 + %3491 = lshr exact i32 %3490, 4, !dbg !160 + %3492 = and i32 %3491, 16383, !dbg !160 + %3493 = zext nneg i32 %3492 to i64, !dbg !160 + %3494 = or disjoint i64 %3493, 4611686293338849280, !dbg !160 + %3495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %3484, 0, !dbg !160 + %3496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 1, !dbg !160 + %3497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 2, !dbg !160 + %3498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 3, !dbg !160 + %3499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 4, !dbg !160 + %3500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 5, !dbg !160 + %3501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 6, !dbg !160 + %3502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 7, !dbg !160 + %3503 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 8, !dbg !160 + %3504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 9, !dbg !160 + %3505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 10, !dbg !160 + %3506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 11, !dbg !160 + %3507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 12, !dbg !160 + %3508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 13, !dbg !160 + %3509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 14, !dbg !160 + %3510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %3484, 15, !dbg !160 + %3511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 16, !dbg !160 + %3512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 17, !dbg !160 + %3513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 18, !dbg !160 + %3514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 19, !dbg !160 + %3515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 20, !dbg !160 + %3516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 21, !dbg !160 + %3517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 22, !dbg !160 + %3518 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 23, !dbg !160 + %3519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 24, !dbg !160 + %3520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 25, !dbg !160 + %3521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 26, !dbg !160 + %3522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 27, !dbg !160 + %3523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 28, !dbg !160 + %3524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 29, !dbg !160 + %3525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 30, !dbg !160 + %3526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3484, 31, !dbg !160 + %3527 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3495, float %3496, float %3497, float %3498, float %3499, float %3500, float %3501, float %3502, float %3503, float %3504, float %3505, float %3506, float %3507, float %3508, float %3509, float %3510, float %3511, float %3512, float %3513, float %3514, float %3515, float %3516, float %3517, float %3518, float %3519, float %3520, float %3521, float %3522, float %3523, float %3524, float %3525, float %3526, i64 %3489, i64 %3494, i1 true) #3, !dbg !160 + %3528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 0, !dbg !160 + %3529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %3527, 1, !dbg !160 + %3530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 2, !dbg !160 + %3531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 3, !dbg !160 + %3532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 4, !dbg !160 + %3533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 5, !dbg !160 + %3534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 6, !dbg !160 + %3535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 7, !dbg !160 + %3536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 8, !dbg !160 + %3537 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 9, !dbg !160 + %3538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 10, !dbg !160 + %3539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 11, !dbg !160 + %3540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 12, !dbg !160 + %3541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 13, !dbg !160 + %3542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 14, !dbg !160 + %3543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 15, !dbg !160 + %3544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 16, !dbg !160 + %3545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 17, !dbg !160 + %3546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 18, !dbg !160 + %3547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 19, !dbg !160 + %3548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 20, !dbg !160 + %3549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 21, !dbg !160 + %3550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 22, !dbg !160 + %3551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %3527, 23, !dbg !160 + %3552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 24, !dbg !160 + %3553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 25, !dbg !160 + %3554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 26, !dbg !160 + %3555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 27, !dbg !160 + %3556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 28, !dbg !160 + %3557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 29, !dbg !160 + %3558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 30, !dbg !160 + %3559 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3527, 31, !dbg !160 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !160 + %3560 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %3528, float %3529, float %3530, float %3531, float %3532, float %3533, float %3534, float %3535, float %3536, float %3537, float %3538, float %3539, float %3540, float %3541, float %3542, float %3543, float %3544, float %3545, float %3546, float %3547, float %3548, float %3549, float %3550, float %3551, float %3552, float %3553, float %3554, float %3555, float %3556, float %3557, float %3558, float %3559, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %3215, i32 0, i32 0) #3, !dbg !160 + %3561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 0, !dbg !160 + %3562 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 1, !dbg !160 + %3563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 2, !dbg !160 + %3564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 3, !dbg !160 + %3565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 4, !dbg !160 + %3566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 5, !dbg !160 + %3567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 6, !dbg !160 + %3568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 7, !dbg !160 + %3569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 8, !dbg !160 + %3570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 9, !dbg !160 + %3571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 10, !dbg !160 + %3572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 11, !dbg !160 + %3573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 12, !dbg !160 + %3574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 13, !dbg !160 + %3575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 14, !dbg !160 + %3576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 15, !dbg !160 + %3577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 16, !dbg !160 + %3578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 17, !dbg !160 + %3579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 18, !dbg !160 + %3580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 19, !dbg !160 + %3581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 20, !dbg !160 + %3582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 21, !dbg !160 + %3583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 22, !dbg !160 + %3584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 23, !dbg !160 + %3585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 24, !dbg !160 + %3586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 25, !dbg !160 + %3587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 26, !dbg !160 + %3588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 27, !dbg !160 + %3589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 28, !dbg !160 + %3590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 29, !dbg !160 + %3591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 30, !dbg !160 + %3592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3560, 31, !dbg !160 + %3593 = insertelement <2 x float> poison, float %3561, i64 0, !dbg !155 + %3594 = insertelement <2 x float> %3593, float %3562, i64 1, !dbg !155 + %3595 = fsub <2 x float> %3594, %2560, !dbg !155 + %3596 = insertelement <2 x float> poison, float %.0.i1317, i64 0, !dbg !161 + %3597 = insertelement <2 x float> %3596, float %.0.i1320, i64 1, !dbg !161 + %3598 = fmul <2 x float> %3597, %3595, !dbg !161 + %3599 = fptrunc <2 x float> %3598 to <2 x bfloat>, !dbg !162 + %3600 = insertelement <2 x float> poison, float %3563, i64 0, !dbg !155 + %3601 = insertelement <2 x float> %3600, float %3564, i64 1, !dbg !155 + %3602 = fsub <2 x float> %3601, %2558, !dbg !155 + %3603 = insertelement <2 x float> poison, float %.0.i1323, i64 0, !dbg !161 + %3604 = insertelement <2 x float> %3603, float %.0.i1326, i64 1, !dbg !161 + %3605 = fmul <2 x float> %3604, %3602, !dbg !161 + %3606 = fptrunc <2 x float> %3605 to <2 x bfloat>, !dbg !162 + %3607 = insertelement <2 x float> poison, float %3565, i64 0, !dbg !155 + %3608 = insertelement <2 x float> %3607, float %3566, i64 1, !dbg !155 + %3609 = fsub <2 x float> %3608, %2560, !dbg !155 + %3610 = insertelement <2 x float> poison, float %.0.i1329, i64 0, !dbg !161 + %3611 = insertelement <2 x float> %3610, float %.0.i1332, i64 1, !dbg !161 + %3612 = fmul <2 x float> %3611, %3609, !dbg !161 + %3613 = fptrunc <2 x float> %3612 to <2 x bfloat>, !dbg !162 + %3614 = insertelement <2 x float> poison, float %3567, i64 0, !dbg !155 + %3615 = insertelement <2 x float> %3614, float %3568, i64 1, !dbg !155 + %3616 = fsub <2 x float> %3615, %2558, !dbg !155 + %3617 = insertelement <2 x float> poison, float %.0.i1335, i64 0, !dbg !161 + %3618 = insertelement <2 x float> %3617, float %.0.i1338, i64 1, !dbg !161 + %3619 = fmul <2 x float> %3618, %3616, !dbg !161 + %3620 = fptrunc <2 x float> %3619 to <2 x bfloat>, !dbg !162 + %3621 = insertelement <2 x float> poison, float 
%3569, i64 0, !dbg !155 + %3622 = insertelement <2 x float> %3621, float %3570, i64 1, !dbg !155 + %3623 = fsub <2 x float> %3622, %2560, !dbg !155 + %3624 = insertelement <2 x float> poison, float %.0.i1341, i64 0, !dbg !161 + %3625 = insertelement <2 x float> %3624, float %.0.i1344, i64 1, !dbg !161 + %3626 = fmul <2 x float> %3625, %3623, !dbg !161 + %3627 = fptrunc <2 x float> %3626 to <2 x bfloat>, !dbg !162 + %3628 = insertelement <2 x float> poison, float %3571, i64 0, !dbg !155 + %3629 = insertelement <2 x float> %3628, float %3572, i64 1, !dbg !155 + %3630 = fsub <2 x float> %3629, %2558, !dbg !155 + %3631 = insertelement <2 x float> poison, float %.0.i1347, i64 0, !dbg !161 + %3632 = insertelement <2 x float> %3631, float %.0.i1350, i64 1, !dbg !161 + %3633 = fmul <2 x float> %3632, %3630, !dbg !161 + %3634 = fptrunc <2 x float> %3633 to <2 x bfloat>, !dbg !162 + %3635 = insertelement <2 x float> poison, float %3573, i64 0, !dbg !155 + %3636 = insertelement <2 x float> %3635, float %3574, i64 1, !dbg !155 + %3637 = fsub <2 x float> %3636, %2560, !dbg !155 + %3638 = insertelement <2 x float> poison, float %.0.i1353, i64 0, !dbg !161 + %3639 = insertelement <2 x float> %3638, float %.0.i1356, i64 1, !dbg !161 + %3640 = fmul <2 x float> %3639, %3637, !dbg !161 + %3641 = fptrunc <2 x float> %3640 to <2 x bfloat>, !dbg !162 + %3642 = insertelement <2 x float> poison, float %3575, i64 0, !dbg !155 + %3643 = insertelement <2 x float> %3642, float %3576, i64 1, !dbg !155 + %3644 = fsub <2 x float> %3643, %2558, !dbg !155 + %3645 = insertelement <2 x float> poison, float %.0.i1359, i64 0, !dbg !161 + %3646 = insertelement <2 x float> %3645, float %.0.i1362, i64 1, !dbg !161 + %3647 = fmul <2 x float> %3646, %3644, !dbg !161 + %3648 = fptrunc <2 x float> %3647 to <2 x bfloat>, !dbg !162 + %3649 = insertelement <2 x float> poison, float %3577, i64 0, !dbg !155 + %3650 = insertelement <2 x float> %3649, float %3578, i64 1, !dbg !155 + %3651 = fsub <2 x float> %3650, 
%2560, !dbg !155 + %3652 = insertelement <2 x float> poison, float %.0.i1365, i64 0, !dbg !161 + %3653 = insertelement <2 x float> %3652, float %.0.i1368, i64 1, !dbg !161 + %3654 = fmul <2 x float> %3653, %3651, !dbg !161 + %3655 = fptrunc <2 x float> %3654 to <2 x bfloat>, !dbg !162 + %3656 = insertelement <2 x float> poison, float %3579, i64 0, !dbg !155 + %3657 = insertelement <2 x float> %3656, float %3580, i64 1, !dbg !155 + %3658 = fsub <2 x float> %3657, %2558, !dbg !155 + %3659 = insertelement <2 x float> poison, float %.0.i1371, i64 0, !dbg !161 + %3660 = insertelement <2 x float> %3659, float %.0.i1374, i64 1, !dbg !161 + %3661 = fmul <2 x float> %3660, %3658, !dbg !161 + %3662 = fptrunc <2 x float> %3661 to <2 x bfloat>, !dbg !162 + %3663 = insertelement <2 x float> poison, float %3581, i64 0, !dbg !155 + %3664 = insertelement <2 x float> %3663, float %3582, i64 1, !dbg !155 + %3665 = fsub <2 x float> %3664, %2560, !dbg !155 + %3666 = insertelement <2 x float> poison, float %.0.i1377, i64 0, !dbg !161 + %3667 = insertelement <2 x float> %3666, float %.0.i1380, i64 1, !dbg !161 + %3668 = fmul <2 x float> %3667, %3665, !dbg !161 + %3669 = fptrunc <2 x float> %3668 to <2 x bfloat>, !dbg !162 + %3670 = insertelement <2 x float> poison, float %3583, i64 0, !dbg !155 + %3671 = insertelement <2 x float> %3670, float %3584, i64 1, !dbg !155 + %3672 = fsub <2 x float> %3671, %2558, !dbg !155 + %3673 = insertelement <2 x float> poison, float %.0.i1383, i64 0, !dbg !161 + %3674 = insertelement <2 x float> %3673, float %.0.i1386, i64 1, !dbg !161 + %3675 = fmul <2 x float> %3674, %3672, !dbg !161 + %3676 = fptrunc <2 x float> %3675 to <2 x bfloat>, !dbg !162 + %3677 = insertelement <2 x float> poison, float %3585, i64 0, !dbg !155 + %3678 = insertelement <2 x float> %3677, float %3586, i64 1, !dbg !155 + %3679 = fsub <2 x float> %3678, %2560, !dbg !155 + %3680 = insertelement <2 x float> poison, float %.0.i1389, i64 0, !dbg !161 + %3681 = insertelement <2 x float> 
%3680, float %.0.i1392, i64 1, !dbg !161 + %3682 = fmul <2 x float> %3681, %3679, !dbg !161 + %3683 = fptrunc <2 x float> %3682 to <2 x bfloat>, !dbg !162 + %3684 = insertelement <2 x float> poison, float %3587, i64 0, !dbg !155 + %3685 = insertelement <2 x float> %3684, float %3588, i64 1, !dbg !155 + %3686 = fsub <2 x float> %3685, %2558, !dbg !155 + %3687 = insertelement <2 x float> poison, float %.0.i1395, i64 0, !dbg !161 + %3688 = insertelement <2 x float> %3687, float %.0.i1398, i64 1, !dbg !161 + %3689 = fmul <2 x float> %3688, %3686, !dbg !161 + %3690 = fptrunc <2 x float> %3689 to <2 x bfloat>, !dbg !162 + %3691 = insertelement <2 x float> poison, float %3589, i64 0, !dbg !155 + %3692 = insertelement <2 x float> %3691, float %3590, i64 1, !dbg !155 + %3693 = fsub <2 x float> %3692, %2560, !dbg !155 + %3694 = insertelement <2 x float> poison, float %.0.i1401, i64 0, !dbg !161 + %3695 = insertelement <2 x float> %3694, float %.0.i1404, i64 1, !dbg !161 + %3696 = fmul <2 x float> %3695, %3693, !dbg !161 + %3697 = fptrunc <2 x float> %3696 to <2 x bfloat>, !dbg !162 + %3698 = insertelement <2 x float> poison, float %3591, i64 0, !dbg !155 + %3699 = insertelement <2 x float> %3698, float %3592, i64 1, !dbg !155 + %3700 = fsub <2 x float> %3699, %2558, !dbg !155 + %3701 = insertelement <2 x float> poison, float %.0.i1407, i64 0, !dbg !161 + %3702 = insertelement <2 x float> %3701, float %.0.i1410, i64 1, !dbg !161 + %3703 = fmul <2 x float> %3702, %3700, !dbg !161 + %3704 = fptrunc <2 x float> %3703 to <2 x bfloat>, !dbg !162 + %3705 = bitcast <2 x bfloat> %3599 to i32, !dbg !163 + %3706 = bitcast <2 x bfloat> %3606 to i32, !dbg !163 + %3707 = bitcast <2 x bfloat> %3613 to i32, !dbg !163 + %3708 = bitcast <2 x bfloat> %3620 to i32, !dbg !163 + %3709 = bitcast <2 x bfloat> %3627 to i32, !dbg !163 + %3710 = bitcast <2 x bfloat> %3634 to i32, !dbg !163 + %3711 = bitcast <2 x bfloat> %3641 to i32, !dbg !163 + %3712 = bitcast <2 x bfloat> %3648 to i32, !dbg !163 + 
%3713 = bitcast <2 x bfloat> %3655 to i32, !dbg !163 + %3714 = bitcast <2 x bfloat> %3662 to i32, !dbg !163 + %3715 = bitcast <2 x bfloat> %3669 to i32, !dbg !163 + %3716 = bitcast <2 x bfloat> %3676 to i32, !dbg !163 + %3717 = bitcast <2 x bfloat> %3683 to i32, !dbg !163 + %3718 = bitcast <2 x bfloat> %3690 to i32, !dbg !163 + %3719 = bitcast <2 x bfloat> %3697 to i32, !dbg !163 + %3720 = bitcast <2 x bfloat> %3704 to i32, !dbg !163 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !163 + %3721 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn, float %.pn2389, float %.pn2390, float %.pn2391, float %.pn2392, float %.pn2393, float %.pn2394, float %.pn2395, float %.pn2396, float %.pn2397, float %.pn2398, float %.pn2399, float %.pn2400, float %.pn2401, float %.pn2402, float %.pn2403, float %.pn2404, 
float %.pn2405, float %.pn2406, float %.pn2407, float %.pn2408, float %.pn2409, float %.pn2410, float %.pn2411, float %.pn2412, float %.pn2413, float %.pn2414, float %.pn2415, float %.pn2416, float %.pn2417, float %.pn2418, float %.pn2419, float %.pn2420, float %.pn2421, float %.pn2422, float %.pn2423, float %.pn2424, float %.pn2425, float %.pn2426, float %.pn2427, float %.pn2428, float %.pn2429, float %.pn2430, float %.pn2431, float %.pn2432, float %.pn2433, float %.pn2434, float %.pn2435, float %.pn2436, float %.pn2437, float %.pn2438, float %.pn2439, float %.pn2440, float %.pn2441, float %.pn2442, float %.pn2443, float %.pn2444, float %.pn2445, float %.pn2446, float %.pn2447, float %.pn2448, float %.pn2449, float %.pn2450, float %.pn2451, i32 %3705, i32 %3706, i32 %3707, i32 %3708, i64 %2584, i1 true) #3, !dbg !163 + %3722 = add i32 %2580, 2048, !dbg !163 + %3723 = lshr exact i32 %3722, 4, !dbg !163 + %3724 = and i32 %3723, 16383, !dbg !163 + %3725 = zext nneg i32 %3724 to i64, !dbg !163 + %3726 = or disjoint i64 %3725, 4611686293338849280, !dbg !163 + %3727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 0, !dbg !163 + %3728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %3721, 1, !dbg !163 + %3729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 2, !dbg !163 + %3730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 3, !dbg !163 + %3731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 4, !dbg !163 + %3732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %3721, 5, !dbg !163 + %3733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 6, !dbg !163 + %3734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 7, !dbg !163 + %3735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 8, !dbg !163 + %3736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%3721, 9, !dbg !163 + %3737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 10, !dbg !163 + %3738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 11, !dbg !163 + %3739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 12, !dbg !163 + %3740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 13, !dbg !163 
+ %3741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 14, !dbg !163 + %3742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 15, !dbg !163 + %3743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 16, !dbg !163 + %3744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 17, !dbg !163 + %3745 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 18, !dbg !163 + %3746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 19, !dbg !163 + %3747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 20, !dbg !163 + %3748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 21, !dbg !163 + %3749 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 22, !dbg !163 + %3750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 23, !dbg !163 + %3751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 24, !dbg !163 + %3752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 25, !dbg !163 + %3753 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 26, !dbg !163 + %3754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 27, !dbg !163 + %3755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 28, !dbg !163 + %3756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 29, !dbg !163 + %3757 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 30, !dbg !163 + %3758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 31, !dbg !163 + %3759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 32, !dbg !163 + %3760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 33, !dbg !163 + %3761 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 34, !dbg !163 + %3762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 35, !dbg !163 + %3763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 36, !dbg !163 + %3764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 37, !dbg !163 + %3765 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 38, !dbg !163 + %3766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 39, !dbg !163 + %3767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 40, !dbg !163 + %3768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 41, !dbg !163 + %3769 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 42, !dbg !163 + %3770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 43, !dbg !163 + %3771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 44, !dbg !163 + %3772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 45, !dbg !163 + %3773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 46, !dbg !163 + %3774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 47, !dbg !163 + %3775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 48, !dbg !163 + %3776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 49, !dbg !163 + %3777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 50, !dbg !163 + %3778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 51, !dbg !163 + %3779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 52, !dbg !163 + %3780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 53, !dbg !163 + %3781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 54, !dbg !163 + %3782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 55, !dbg !163 + %3783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 56, !dbg !163 + %3784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 57, !dbg !163 + %3785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 58, !dbg !163 + %3786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 59, !dbg !163 + %3787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 60, !dbg !163 + %3788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 61, !dbg !163 + %3789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 62, !dbg !163 + %3790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3721, 63, !dbg !163 + %3791 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %3727, float %3728, float %3729, float %3730, float %3731, float %3732, float %3733, float %3734, float %3735, float %3736, float %3737, float %3738, float %3739, float %3740, float %3741, float %3742, float %3743, float %3744, float %3745, float %3746, float %3747, float %3748, float %3749, float %3750, float %3751, float %3752, float %3753, float %3754, float %3755, float %3756, float %3757, float %3758, float %3759, float %3760, float %3761, float %3762, float %3763, float %3764, float %3765, float %3766, float %3767, float %3768, float %3769, float %3770, float %3771, float %3772, float %3773, float %3774, float %3775, float %3776, float %3777, float %3778, float %3779, float %3780, float %3781, float %3782, float %3783, float %3784, float %3785, float %3786, float %3787, float %3788, float %3789, float %3790, i32 %3709, i32 %3710, i32 %3711, i32 %3712, i64 %3726, i1 true) #3, !dbg !163 + %3792 = add i32 %2580, 4096, !dbg !163 + %3793 = lshr exact i32 %3792, 4, !dbg !163 + %3794 = and i32 %3793, 16383, !dbg !163 + %3795 = zext nneg i32 %3794 to i64, !dbg !163 + %3796 = or disjoint i64 %3795, 4611686293338849280, !dbg !163 + %3797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %3791, 0, !dbg !163 + %3798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 1, !dbg !163 + %3799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 2, !dbg !163 + %3800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 3, !dbg !163 + %3801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 
4, !dbg !163 + %3802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 5, !dbg !163 + %3803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 6, !dbg !163 + %3804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 7, !dbg !163 + %3805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 8, !dbg !163 + %3806 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 9, !dbg !163 + %3807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 10, !dbg !163 + %3808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 11, !dbg !163 + %3809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 12, !dbg !163 + %3810 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 13, !dbg !163 + %3811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 14, !dbg !163 + %3812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 15, !dbg !163 + %3813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 16, !dbg !163 + %3814 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 17, !dbg !163 + %3815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 18, !dbg !163 + %3816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 19, !dbg !163 + %3817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 20, !dbg !163 + %3818 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 21, !dbg !163 + %3819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 22, !dbg !163 + %3820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 23, !dbg !163 + %3821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 24, !dbg !163 + %3822 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 25, !dbg !163 + %3823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 26, !dbg !163 + %3824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 27, !dbg !163 + %3825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 28, !dbg !163 + %3826 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 29, !dbg !163 + %3827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 30, !dbg !163 + %3828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 31, !dbg !163 + %3829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 32, !dbg !163 + %3830 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 33, !dbg !163 + %3831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 34, !dbg !163 + %3832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 35, !dbg !163 + %3833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 36, !dbg !163 + %3834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 37, !dbg !163 + %3835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 38, !dbg !163 + %3836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 39, !dbg !163 + %3837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 40, !dbg !163 + %3838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 41, !dbg !163 + %3839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 42, !dbg !163 + %3840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 43, !dbg !163 + %3841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 44, !dbg !163 + %3842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 45, !dbg !163 + %3843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 46, !dbg !163 + %3844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 47, !dbg !163 + %3845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 48, !dbg !163 + %3846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 49, !dbg !163 + %3847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 50, !dbg !163 + %3848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 51, !dbg !163 + %3849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 52, !dbg !163 + %3850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 53, !dbg !163 + %3851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 54, !dbg !163 + %3852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 55, !dbg !163 + %3853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 56, !dbg !163 + %3854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 57, !dbg !163 + %3855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 58, !dbg !163 + %3856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 59, !dbg !163 + %3857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 60, !dbg !163 + %3858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 61, !dbg !163 + %3859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 62, !dbg !163 + %3860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3791, 63, !dbg !163 + %3861 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %3797, float %3798, float %3799, float %3800, float %3801, float %3802, float %3803, float %3804, float %3805, float %3806, float %3807, float %3808, float %3809, float %3810, float %3811, float %3812, float %3813, float %3814, float %3815, float %3816, float %3817, float %3818, float %3819, float %3820, float %3821, float %3822, float %3823, float %3824, float %3825, float %3826, float %3827, float %3828, float %3829, float %3830, float %3831, float %3832, float %3833, float %3834, float %3835, float %3836, float %3837, float %3838, float %3839, float %3840, float %3841, float %3842, float %3843, float %3844, float %3845, float %3846, float %3847, float %3848, float %3849, float %3850, float %3851, float %3852, float %3853, float %3854, float %3855, float %3856, float %3857, float %3858, float %3859, float %3860, i32 %3713, i32 %3714, i32 %3715, i32 %3716, i64 %3796, i1 true) #3, !dbg !163 + %3862 = add i32 %2580, 6144, !dbg !163 + %3863 = lshr exact i32 %3862, 4, !dbg !163 + %3864 = and i32 %3863, 16383, !dbg !163 + %3865 = zext nneg i32 %3864 to i64, !dbg !163 + %3866 = or disjoint i64 %3865, 4611686293338849280, !dbg !163 + %3867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 0, !dbg !163 + %3868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 1, !dbg !163 + %3869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 2, !dbg !163 + %3870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 3, !dbg !163 + %3871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 4, !dbg !163 + %3872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 5, !dbg !163 + %3873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 6, !dbg !163 + %3874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 7, !dbg !163 + %3875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 8, !dbg !163 + %3876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 9, !dbg !163 + %3877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 10, !dbg !163 + %3878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 11, !dbg !163 + %3879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 12, !dbg !163 + %3880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 13, !dbg !163 + %3881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 14, !dbg !163 + %3882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 15, !dbg !163 + %3883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 16, !dbg !163 + %3884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 17, !dbg !163 + %3885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 18, !dbg !163 + %3886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 19, !dbg !163 + %3887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 20, !dbg !163 + %3888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 21, !dbg !163 + %3889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 22, !dbg !163 + %3890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 23, !dbg !163 + %3891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 24, !dbg !163 + %3892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 25, !dbg !163 + %3893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 26, !dbg !163 + %3894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 27, !dbg !163 + %3895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 28, !dbg !163 + %3896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 29, !dbg !163 + %3897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 30, !dbg !163 + %3898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 31, !dbg !163 + %3899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 32, !dbg !163 + %3900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 33, !dbg !163 + %3901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 34, !dbg !163 + %3902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 35, !dbg !163 + %3903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 36, !dbg !163 + %3904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 37, !dbg !163 + %3905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 38, !dbg !163 + %3906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 39, !dbg !163 + %3907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 40, !dbg !163 + %3908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 41, !dbg !163 + %3909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 42, !dbg !163 + %3910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 43, !dbg !163 + %3911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 44, !dbg !163 + %3912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 45, !dbg !163 + %3913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 46, !dbg !163 + %3914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 47, !dbg !163 + %3915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 48, !dbg !163 + %3916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 49, !dbg !163 + %3917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 50, !dbg !163 + %3918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 51, !dbg !163 + %3919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 52, !dbg !163 + %3920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 53, !dbg !163 + %3921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 54, !dbg !163 + %3922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 55, !dbg !163 + %3923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 56, !dbg !163 + %3924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 57, !dbg !163 + %3925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 58, !dbg !163 + %3926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 59, !dbg !163 + %3927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %3861, 60, !dbg !163 + %3928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 61, !dbg !163 + %3929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 62, !dbg !163 + %3930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3861, 63, !dbg !163 + %3931 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %3867, float %3868, float %3869, float %3870, float %3871, float %3872, float %3873, float %3874, float %3875, float %3876, float %3877, float %3878, float %3879, float %3880, float %3881, float %3882, float %3883, float %3884, float %3885, float %3886, float %3887, float %3888, float %3889, float %3890, float %3891, float %3892, float %3893, float %3894, float %3895, float %3896, float %3897, float %3898, float %3899, float %3900, float %3901, float %3902, float %3903, float %3904, float %3905, float %3906, float %3907, float %3908, float %3909, float %3910, float %3911, float %3912, float %3913, float %3914, float %3915, float %3916, float %3917, float %3918, float %3919, float %3920, float %3921, float %3922, float %3923, float %3924, float %3925, float %3926, float %3927, float %3928, float %3929, float %3930, i32 %3717, i32 %3718, i32 %3719, i32 %3720, i64 %3866, i1 true) #3, !dbg !163 + %3932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 0, !dbg !163 + %3933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 1, !dbg !163 + %3934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 2, !dbg !163 + %3935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 3, !dbg !163 + %3936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 4, !dbg !163 + %3937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 5, !dbg !163 + %3938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 6, !dbg !163 + %3939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 7, !dbg !163 + %3940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 8, !dbg !163 + %3941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 9, !dbg !163 + %3942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 10, !dbg !163 + %3943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 11, !dbg !163 + %3944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 12, !dbg !163 + %3945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 13, !dbg !163 + %3946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 14, !dbg !163 + %3947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 15, !dbg !163 + %3948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 16, !dbg !163 + %3949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 17, !dbg !163 + %3950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 18, !dbg !163 + %3951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 19, !dbg !163 + %3952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 20, !dbg !163 + %3953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 21, !dbg !163 + %3954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 22, !dbg !163 + %3955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 23, !dbg !163 + %3956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 24, !dbg !163 + %3957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 25, !dbg !163 + %3958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 26, !dbg !163 + %3959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 27, !dbg !163 + %3960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %3931, 28, !dbg !163 + %3961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 29, !dbg !163 + %3962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 30, !dbg !163 + %3963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 31, !dbg !163 + %3964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %3931, 32, !dbg !163 + %3965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 33, !dbg !163 + %3966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 34, !dbg !163 + %3967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 35, !dbg !163 + %3968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %3931, 36, !dbg !163 + %3969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 37, !dbg !163 + %3970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 38, !dbg !163 + %3971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 39, !dbg !163 + %3972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %3931, 40, !dbg !163 + %3973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 41, !dbg !163 + %3974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 42, !dbg !163 + %3975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 43, !dbg !163 + %3976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %3931, 44, !dbg !163 + %3977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 45, !dbg !163 + %3978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 46, !dbg !163 + %3979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 47, !dbg !163 + %3980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float 
} %3931, 48, !dbg !163 + %3981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 49, !dbg !163 + %3982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 50, !dbg !163 + %3983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 51, !dbg !163 + %3984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 52, !dbg 
!163 + %3985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 53, !dbg !163 + %3986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 54, !dbg !163 + %3987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 55, !dbg !163 + %3988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 56, !dbg !163 + %3989 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 57, !dbg !163 + %3990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 58, !dbg !163 + %3991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 59, !dbg !163 + %3992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 60, !dbg !163 + %3993 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 61, !dbg !163 + %3994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 62, !dbg !163 + %3995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3931, 63, !dbg !163 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !163 + %3996 = add nuw nsw i32 %2564, 1, !dbg !149 + %3997 = lshr i32 %3996, 1, !dbg !164 + %3998 = zext nneg i32 %3997 to i64, !dbg !165 + %3999 = getelementptr i32, ptr addrspace(1) %2380, i64 %3998, !dbg !165 + %4000 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !166 + %4001 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 
+ 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %3999, i64 %4000, i1 %2566) #3, !dbg !166 + %4002 = add nuw nsw i32 %3997, 1, !dbg !167 + %4003 = icmp slt i32 %4002, %2384, !dbg !168 + %4004 = getelementptr i8, ptr addrspace(1) %3999, i64 4, !dbg !169 + %4005 = and i1 %2566, %4003, !dbg !149 + %4006 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !170 + %4007 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %4004, i64 %4006, i1 %4005) #3, !dbg !170 + %4008 = and i32 %2564, 1, !dbg !171 + %4009 = sub i32 %4007, %4001, !dbg !172 + %4010 = shl i32 %4009, 7, !dbg !173 + %4011 = add i32 %4010, 33554368, !dbg !174 + %4012 = shl nuw nsw i32 %4008, 13, !dbg !175 + %4013 = shl nuw nsw i32 %4008, 7, !dbg !176 + %4014 = xor i32 %4013, 128, !dbg !176 + %4015 = mul i32 %4014, %4011, !dbg !175 + %4016 = add i32 %4015, %4012, !dbg !175 + %4017 = sext i32 %4016 to i64, !dbg !151 + %4018 = getelementptr bfloat, ptr addrspace(1) %.pn10621849, i64 %4017, !dbg !151 + %4019 = getelementptr bfloat, ptr addrspace(1) %.pn10461850, i64 %4017, !dbg !151 + %4020 = getelementptr bfloat, ptr addrspace(1) %.pn10301851, i64 %4017, !dbg !151 + %4021 = getelementptr bfloat, ptr addrspace(1) %.pn10141852, i64 %4017, !dbg !151 + %4022 = getelementptr bfloat, ptr addrspace(1) %.pn11261853, i64 %4017, !dbg !152 + %4023 = getelementptr bfloat, ptr addrspace(1) %.pn11101854, i64 %4017, !dbg !152 + %4024 = getelementptr bfloat, ptr addrspace(1) %.pn10941855, i64 %4017, !dbg !152 + %4025 = getelementptr bfloat, ptr addrspace(1) %.pn10781856, i64 %4017, !dbg !152 + %4026 = add i32 %2563, 1, !dbg !149 + %4027 = icmp sgt i32 %4026, 2, !dbg !149 + %4028 = select i1 %4027, i32 0, i32 %4026, !dbg !149 + %4029 = shl i32 %4028, 13, !dbg !150 + %4030 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %4029, !dbg !150 + tail call void 
@llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !150 + %4031 = getelementptr inbounds nuw i8, ptr addrspace(3) %4030, i32 %402, !dbg !150 + %4032 = select i1 %2565, i32 16, i32 0, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4031, ptr addrspace(1) %4018, i32 %4032) #3, !dbg !150 + %4033 = getelementptr inbounds nuw i8, ptr addrspace(3) %4030, i32 %405, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4033, ptr addrspace(1) %4019, i32 %4032) #3, !dbg !150 + %4034 = getelementptr inbounds nuw i8, ptr addrspace(3) %4030, i32 %407, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4034, ptr addrspace(1) %4020, i32 %4032) #3, !dbg !150 + %4035 = getelementptr inbounds nuw i8, ptr addrspace(3) %4030, i32 %409, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4035, ptr addrspace(1) %4021, i32 %4032) #3, !dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + %4036 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4029, !dbg !150 + %4037 = getelementptr inbounds nuw i8, ptr addrspace(3) %4036, i32 %402, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4037, ptr addrspace(1) %4022, i32 %4032) #3, !dbg !150 + %4038 = getelementptr inbounds nuw i8, ptr addrspace(3) %4036, i32 %405, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4038, ptr addrspace(1) %4023, i32 %4032) #3, !dbg !150 + %4039 = getelementptr inbounds nuw i8, ptr addrspace(3) %4036, i32 %407, !dbg !150 + tail call 
void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4039, ptr addrspace(1) %4024, i32 %4032) #3, !dbg !150 + %4040 = getelementptr inbounds nuw i8, ptr addrspace(3) %4036, i32 %409, !dbg !150 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4040, ptr addrspace(1) %4025, i32 %4032) #3, !dbg !150 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !150 + %exitcond2127.not = icmp eq i32 %3996, %2490, !dbg !149 + br i1 %exitcond2127.not, label %._crit_edge1859, label %2561, !dbg !149 + +._crit_edge1859: ; preds = %__nv_exp2f.exit1411, %._crit_edge1847 + %4041 = phi float [ %2426, %._crit_edge1847 ], [ %3932, %__nv_exp2f.exit1411 ], !dbg !77 + %4042 = phi float [ %2427, %._crit_edge1847 ], [ %3933, %__nv_exp2f.exit1411 ], !dbg !77 + %4043 = phi float [ %2428, %._crit_edge1847 ], [ %3934, %__nv_exp2f.exit1411 ], !dbg !77 + %4044 = phi float [ %2429, %._crit_edge1847 ], [ %3935, %__nv_exp2f.exit1411 ], !dbg !77 + %4045 = phi float [ %2430, %._crit_edge1847 ], [ %3936, %__nv_exp2f.exit1411 ], !dbg !77 + %4046 = phi float [ %2431, %._crit_edge1847 ], [ %3937, %__nv_exp2f.exit1411 ], !dbg !77 + %4047 = phi float [ %2432, %._crit_edge1847 ], [ %3938, %__nv_exp2f.exit1411 ], !dbg !77 + %4048 = phi float [ %2433, %._crit_edge1847 ], [ %3939, %__nv_exp2f.exit1411 ], !dbg !77 + %4049 = phi float [ %2434, %._crit_edge1847 ], [ %3940, %__nv_exp2f.exit1411 ], !dbg !77 + %4050 = phi float [ %2435, %._crit_edge1847 ], [ %3941, %__nv_exp2f.exit1411 ], !dbg !77 + %4051 = phi float [ %2436, %._crit_edge1847 ], [ %3942, %__nv_exp2f.exit1411 ], !dbg !77 + %4052 = phi float [ %2437, %._crit_edge1847 ], [ %3943, %__nv_exp2f.exit1411 ], !dbg !77 + %4053 = phi float [ %2438, %._crit_edge1847 ], [ %3944, %__nv_exp2f.exit1411 ], !dbg !77 + %4054 = phi float [ %2439, %._crit_edge1847 ], [ %3945, %__nv_exp2f.exit1411 ], !dbg !77 + %4055 = phi float 
[ %2440, %._crit_edge1847 ], [ %3946, %__nv_exp2f.exit1411 ], !dbg !77 + %4056 = phi float [ %2441, %._crit_edge1847 ], [ %3947, %__nv_exp2f.exit1411 ], !dbg !77 + %4057 = phi float [ %2442, %._crit_edge1847 ], [ %3948, %__nv_exp2f.exit1411 ], !dbg !77 + %4058 = phi float [ %2443, %._crit_edge1847 ], [ %3949, %__nv_exp2f.exit1411 ], !dbg !77 + %4059 = phi float [ %2444, %._crit_edge1847 ], [ %3950, %__nv_exp2f.exit1411 ], !dbg !77 + %4060 = phi float [ %2445, %._crit_edge1847 ], [ %3951, %__nv_exp2f.exit1411 ], !dbg !77 + %4061 = phi float [ %2446, %._crit_edge1847 ], [ %3952, %__nv_exp2f.exit1411 ], !dbg !77 + %4062 = phi float [ %2447, %._crit_edge1847 ], [ %3953, %__nv_exp2f.exit1411 ], !dbg !77 + %4063 = phi float [ %2448, %._crit_edge1847 ], [ %3954, %__nv_exp2f.exit1411 ], !dbg !77 + %4064 = phi float [ %2449, %._crit_edge1847 ], [ %3955, %__nv_exp2f.exit1411 ], !dbg !77 + %4065 = phi float [ %2450, %._crit_edge1847 ], [ %3956, %__nv_exp2f.exit1411 ], !dbg !77 + %4066 = phi float [ %2451, %._crit_edge1847 ], [ %3957, %__nv_exp2f.exit1411 ], !dbg !77 + %4067 = phi float [ %2452, %._crit_edge1847 ], [ %3958, %__nv_exp2f.exit1411 ], !dbg !77 + %4068 = phi float [ %2453, %._crit_edge1847 ], [ %3959, %__nv_exp2f.exit1411 ], !dbg !77 + %4069 = phi float [ %2454, %._crit_edge1847 ], [ %3960, %__nv_exp2f.exit1411 ], !dbg !77 + %4070 = phi float [ %2455, %._crit_edge1847 ], [ %3961, %__nv_exp2f.exit1411 ], !dbg !77 + %4071 = phi float [ %2456, %._crit_edge1847 ], [ %3962, %__nv_exp2f.exit1411 ], !dbg !77 + %4072 = phi float [ %2457, %._crit_edge1847 ], [ %3963, %__nv_exp2f.exit1411 ], !dbg !77 + %4073 = phi float [ %2458, %._crit_edge1847 ], [ %3964, %__nv_exp2f.exit1411 ], !dbg !77 + %4074 = phi float [ %2459, %._crit_edge1847 ], [ %3965, %__nv_exp2f.exit1411 ], !dbg !77 + %4075 = phi float [ %2460, %._crit_edge1847 ], [ %3966, %__nv_exp2f.exit1411 ], !dbg !77 + %4076 = phi float [ %2461, %._crit_edge1847 ], [ %3967, %__nv_exp2f.exit1411 ], !dbg !77 + %4077 = phi 
float [ %2462, %._crit_edge1847 ], [ %3968, %__nv_exp2f.exit1411 ], !dbg !77 + %4078 = phi float [ %2463, %._crit_edge1847 ], [ %3969, %__nv_exp2f.exit1411 ], !dbg !77 + %4079 = phi float [ %2464, %._crit_edge1847 ], [ %3970, %__nv_exp2f.exit1411 ], !dbg !77 + %4080 = phi float [ %2465, %._crit_edge1847 ], [ %3971, %__nv_exp2f.exit1411 ], !dbg !77 + %4081 = phi float [ %2466, %._crit_edge1847 ], [ %3972, %__nv_exp2f.exit1411 ], !dbg !77 + %4082 = phi float [ %2467, %._crit_edge1847 ], [ %3973, %__nv_exp2f.exit1411 ], !dbg !77 + %4083 = phi float [ %2468, %._crit_edge1847 ], [ %3974, %__nv_exp2f.exit1411 ], !dbg !77 + %4084 = phi float [ %2469, %._crit_edge1847 ], [ %3975, %__nv_exp2f.exit1411 ], !dbg !77 + %4085 = phi float [ %2470, %._crit_edge1847 ], [ %3976, %__nv_exp2f.exit1411 ], !dbg !77 + %4086 = phi float [ %2471, %._crit_edge1847 ], [ %3977, %__nv_exp2f.exit1411 ], !dbg !77 + %4087 = phi float [ %2472, %._crit_edge1847 ], [ %3978, %__nv_exp2f.exit1411 ], !dbg !77 + %4088 = phi float [ %2473, %._crit_edge1847 ], [ %3979, %__nv_exp2f.exit1411 ], !dbg !77 + %4089 = phi float [ %2474, %._crit_edge1847 ], [ %3980, %__nv_exp2f.exit1411 ], !dbg !77 + %4090 = phi float [ %2475, %._crit_edge1847 ], [ %3981, %__nv_exp2f.exit1411 ], !dbg !77 + %4091 = phi float [ %2476, %._crit_edge1847 ], [ %3982, %__nv_exp2f.exit1411 ], !dbg !77 + %4092 = phi float [ %2477, %._crit_edge1847 ], [ %3983, %__nv_exp2f.exit1411 ], !dbg !77 + %4093 = phi float [ %2478, %._crit_edge1847 ], [ %3984, %__nv_exp2f.exit1411 ], !dbg !77 + %4094 = phi float [ %2479, %._crit_edge1847 ], [ %3985, %__nv_exp2f.exit1411 ], !dbg !77 + %4095 = phi float [ %2480, %._crit_edge1847 ], [ %3986, %__nv_exp2f.exit1411 ], !dbg !77 + %4096 = phi float [ %2481, %._crit_edge1847 ], [ %3987, %__nv_exp2f.exit1411 ], !dbg !77 + %4097 = phi float [ %2482, %._crit_edge1847 ], [ %3988, %__nv_exp2f.exit1411 ], !dbg !77 + %4098 = phi float [ %2483, %._crit_edge1847 ], [ %3989, %__nv_exp2f.exit1411 ], !dbg !77 + %4099 = 
phi float [ %2484, %._crit_edge1847 ], [ %3990, %__nv_exp2f.exit1411 ], !dbg !77 + %4100 = phi float [ %2485, %._crit_edge1847 ], [ %3991, %__nv_exp2f.exit1411 ], !dbg !77 + %4101 = phi float [ %2486, %._crit_edge1847 ], [ %3992, %__nv_exp2f.exit1411 ], !dbg !77 + %4102 = phi float [ %2487, %._crit_edge1847 ], [ %3993, %__nv_exp2f.exit1411 ], !dbg !77 + %4103 = phi float [ %2488, %._crit_edge1847 ], [ %3994, %__nv_exp2f.exit1411 ], !dbg !77 + %4104 = phi float [ %2489, %._crit_edge1847 ], [ %3995, %__nv_exp2f.exit1411 ], !dbg !77 + %4105 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %4041, float %4042, float %4043, float %4044, float %4045, float %4046, float %4047, float %4048, float %4049, float %4050, float %4051, float %4052, float %4053, float %4054, float %4055, float %4056, float %4057, float %4058, float %4059, float %4060, float %4061, float %4062, float 
%4063, float %4064, float %4065, float %4066, float %4067, float %4068, float %4069, float %4070, float %4071, float %4072, float %4073, float %4074, float %4075, float %4076, float %4077, float %4078, float %4079, float %4080, float %4081, float %4082, float %4083, float %4084, float %4085, float %4086, float %4087, float %4088, float %4089, float %4090, float %4091, float %4092, float %4093, float %4094, float %4095, float %4096, float %4097, float %4098, float %4099, float %4100, float %4101, float %4102, float %4103, float %4104) #3, !dbg !149 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !149 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !149 + %4106 = getelementptr bfloat, ptr addrspace(1) %77, i64 %103, !dbg !177 + %4107 = getelementptr bfloat, ptr addrspace(1) %77, i64 %105, !dbg !177 + %4108 = getelementptr bfloat, ptr addrspace(1) %77, i64 %107, !dbg !177 + %4109 = getelementptr bfloat, ptr addrspace(1) %77, i64 %109, !dbg !177 + %4110 = getelementptr bfloat, ptr addrspace(1) %77, i64 %111, !dbg !177 + %4111 = getelementptr bfloat, ptr addrspace(1) %77, i64 %113, !dbg !177 + %4112 = getelementptr bfloat, ptr addrspace(1) %77, i64 %115, !dbg !177 + %4113 = getelementptr bfloat, ptr addrspace(1) %77, i64 %117, !dbg !177 + %4114 = getelementptr bfloat, ptr addrspace(1) %4106, i64 %121, !dbg !178 + %4115 = getelementptr bfloat, ptr addrspace(1) %4107, i64 %121, !dbg !178 + %4116 = getelementptr bfloat, ptr addrspace(1) %4108, i64 %121, !dbg !178 + %4117 = getelementptr bfloat, ptr addrspace(1) %4109, i64 %121, !dbg !178 + %4118 = getelementptr bfloat, ptr addrspace(1) %4110, i64 %121, !dbg !178 + %4119 = getelementptr bfloat, ptr addrspace(1) %4111, i64 %121, !dbg !178 + %4120 = getelementptr bfloat, ptr addrspace(1) %4112, i64 %121, !dbg !178 + %4121 = getelementptr bfloat, ptr addrspace(1) %4113, i64 %121, !dbg !178 + %4122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 0, !dbg !179 + %4123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 1, !dbg !179 + %4124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 2, !dbg !179 + %4125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 3, !dbg !179 + %4126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 4, !dbg !179 + %4127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 5, !dbg !179 + %4128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 6, !dbg !179 + %4129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 7, !dbg !179 + %4130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 8, !dbg !179 + %4131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 9, !dbg !179 + %4132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 10, !dbg !179 + %4133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 11, !dbg !179 + %4134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 12, !dbg !179 + %4135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 13, !dbg !179 + %4136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 14, !dbg !179 + %4137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 15, !dbg !179 + %4138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 16, !dbg !179 + %4139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 17, !dbg !179 + %4140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 18, !dbg !179 + %4141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 19, !dbg !179 + %4142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 20, !dbg !179 + %4143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 21, !dbg !179 + %4144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 22, !dbg !179 + %4145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 23, !dbg !179 + %4146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 24, !dbg !179 + %4147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 25, !dbg !179 + %4148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 26, !dbg !179 + %4149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 27, !dbg !179 + %4150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 28, !dbg !179 + %4151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 29, !dbg !179 + %4152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 30, !dbg !179 + %4153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 31, !dbg !179 + %4154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 32, !dbg !179 + %4155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 33, !dbg !179 + %4156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 34, !dbg !179 + %4157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 35, !dbg !179 + %4158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 36, !dbg !179 + %4159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 37, !dbg !179 + %4160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 38, !dbg !179 + %4161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 39, !dbg !179 + %4162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 40, !dbg !179 + %4163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 41, !dbg !179 + %4164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 42, !dbg !179 + %4165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 43, !dbg !179 + %4166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 44, !dbg !179 + %4167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 45, !dbg !179 + %4168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 46, !dbg !179 + %4169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 47, !dbg !179 + %4170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 48, !dbg !179 + %4171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 49, !dbg !179 + %4172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 50, !dbg !179 + %4173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 51, !dbg !179 + %4174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 52, !dbg !179 + %4175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 53, !dbg !179 + %4176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 54, !dbg !179 + %4177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 55, !dbg !179 + %4178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 56, !dbg !179 + %4179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 57, !dbg !179 + %4180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 58, !dbg !179 + %4181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 59, !dbg !179 + %4182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 60, !dbg !179 + %4183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 61, !dbg !179 + %4184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 62, !dbg !179 + %4185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4105, 63, !dbg !179 + %4186 = insertelement <2 x float> poison, float %4122, i64 0, !dbg !179 + %4187 = insertelement <2 x float> %4186, float %4123, i64 1, !dbg !179 + %4188 = fmul <2 x float> %4187, splat (float 0x3FB6A09E60000000), !dbg !179 + %4189 = fptrunc <2 x float> %4188 to <2 x bfloat>, !dbg !180 + %4190 = insertelement <2 x float> poison, float %4124, i64 0, 
!dbg !179 + %4191 = insertelement <2 x float> %4190, float %4125, i64 1, !dbg !179 + %4192 = fmul <2 x float> %4191, splat (float 0x3FB6A09E60000000), !dbg !179 + %4193 = fptrunc <2 x float> %4192 to <2 x bfloat>, !dbg !180 + %4194 = insertelement <2 x float> poison, float %4126, i64 0, !dbg !179 + %4195 = insertelement <2 x float> %4194, float %4127, i64 1, !dbg !179 + %4196 = fmul <2 x float> %4195, splat (float 0x3FB6A09E60000000), !dbg !179 + %4197 = fptrunc <2 x float> %4196 to <2 x bfloat>, !dbg !180 + %4198 = insertelement <2 x float> poison, float %4128, i64 0, !dbg !179 + %4199 = insertelement <2 x float> %4198, float %4129, i64 1, !dbg !179 + %4200 = fmul <2 x float> %4199, splat (float 0x3FB6A09E60000000), !dbg !179 + %4201 = fptrunc <2 x float> %4200 to <2 x bfloat>, !dbg !180 + %4202 = insertelement <2 x float> poison, float %4130, i64 0, !dbg !179 + %4203 = insertelement <2 x float> %4202, float %4131, i64 1, !dbg !179 + %4204 = fmul <2 x float> %4203, splat (float 0x3FB6A09E60000000), !dbg !179 + %4205 = fptrunc <2 x float> %4204 to <2 x bfloat>, !dbg !180 + %4206 = insertelement <2 x float> poison, float %4132, i64 0, !dbg !179 + %4207 = insertelement <2 x float> %4206, float %4133, i64 1, !dbg !179 + %4208 = fmul <2 x float> %4207, splat (float 0x3FB6A09E60000000), !dbg !179 + %4209 = fptrunc <2 x float> %4208 to <2 x bfloat>, !dbg !180 + %4210 = insertelement <2 x float> poison, float %4134, i64 0, !dbg !179 + %4211 = insertelement <2 x float> %4210, float %4135, i64 1, !dbg !179 + %4212 = fmul <2 x float> %4211, splat (float 0x3FB6A09E60000000), !dbg !179 + %4213 = fptrunc <2 x float> %4212 to <2 x bfloat>, !dbg !180 + %4214 = insertelement <2 x float> poison, float %4136, i64 0, !dbg !179 + %4215 = insertelement <2 x float> %4214, float %4137, i64 1, !dbg !179 + %4216 = fmul <2 x float> %4215, splat (float 0x3FB6A09E60000000), !dbg !179 + %4217 = fptrunc <2 x float> %4216 to <2 x bfloat>, !dbg !180 + %4218 = insertelement <2 x float> poison, 
float %4138, i64 0, !dbg !179 + %4219 = insertelement <2 x float> %4218, float %4139, i64 1, !dbg !179 + %4220 = fmul <2 x float> %4219, splat (float 0x3FB6A09E60000000), !dbg !179 + %4221 = fptrunc <2 x float> %4220 to <2 x bfloat>, !dbg !180 + %4222 = insertelement <2 x float> poison, float %4140, i64 0, !dbg !179 + %4223 = insertelement <2 x float> %4222, float %4141, i64 1, !dbg !179 + %4224 = fmul <2 x float> %4223, splat (float 0x3FB6A09E60000000), !dbg !179 + %4225 = fptrunc <2 x float> %4224 to <2 x bfloat>, !dbg !180 + %4226 = insertelement <2 x float> poison, float %4142, i64 0, !dbg !179 + %4227 = insertelement <2 x float> %4226, float %4143, i64 1, !dbg !179 + %4228 = fmul <2 x float> %4227, splat (float 0x3FB6A09E60000000), !dbg !179 + %4229 = fptrunc <2 x float> %4228 to <2 x bfloat>, !dbg !180 + %4230 = insertelement <2 x float> poison, float %4144, i64 0, !dbg !179 + %4231 = insertelement <2 x float> %4230, float %4145, i64 1, !dbg !179 + %4232 = fmul <2 x float> %4231, splat (float 0x3FB6A09E60000000), !dbg !179 + %4233 = fptrunc <2 x float> %4232 to <2 x bfloat>, !dbg !180 + %4234 = insertelement <2 x float> poison, float %4146, i64 0, !dbg !179 + %4235 = insertelement <2 x float> %4234, float %4147, i64 1, !dbg !179 + %4236 = fmul <2 x float> %4235, splat (float 0x3FB6A09E60000000), !dbg !179 + %4237 = fptrunc <2 x float> %4236 to <2 x bfloat>, !dbg !180 + %4238 = insertelement <2 x float> poison, float %4148, i64 0, !dbg !179 + %4239 = insertelement <2 x float> %4238, float %4149, i64 1, !dbg !179 + %4240 = fmul <2 x float> %4239, splat (float 0x3FB6A09E60000000), !dbg !179 + %4241 = fptrunc <2 x float> %4240 to <2 x bfloat>, !dbg !180 + %4242 = insertelement <2 x float> poison, float %4150, i64 0, !dbg !179 + %4243 = insertelement <2 x float> %4242, float %4151, i64 1, !dbg !179 + %4244 = fmul <2 x float> %4243, splat (float 0x3FB6A09E60000000), !dbg !179 + %4245 = fptrunc <2 x float> %4244 to <2 x bfloat>, !dbg !180 + %4246 = insertelement <2 
x float> poison, float %4152, i64 0, !dbg !179 + %4247 = insertelement <2 x float> %4246, float %4153, i64 1, !dbg !179 + %4248 = fmul <2 x float> %4247, splat (float 0x3FB6A09E60000000), !dbg !179 + %4249 = fptrunc <2 x float> %4248 to <2 x bfloat>, !dbg !180 + %4250 = insertelement <2 x float> poison, float %4154, i64 0, !dbg !179 + %4251 = insertelement <2 x float> %4250, float %4155, i64 1, !dbg !179 + %4252 = fmul <2 x float> %4251, splat (float 0x3FB6A09E60000000), !dbg !179 + %4253 = fptrunc <2 x float> %4252 to <2 x bfloat>, !dbg !180 + %4254 = insertelement <2 x float> poison, float %4156, i64 0, !dbg !179 + %4255 = insertelement <2 x float> %4254, float %4157, i64 1, !dbg !179 + %4256 = fmul <2 x float> %4255, splat (float 0x3FB6A09E60000000), !dbg !179 + %4257 = fptrunc <2 x float> %4256 to <2 x bfloat>, !dbg !180 + %4258 = insertelement <2 x float> poison, float %4158, i64 0, !dbg !179 + %4259 = insertelement <2 x float> %4258, float %4159, i64 1, !dbg !179 + %4260 = fmul <2 x float> %4259, splat (float 0x3FB6A09E60000000), !dbg !179 + %4261 = fptrunc <2 x float> %4260 to <2 x bfloat>, !dbg !180 + %4262 = insertelement <2 x float> poison, float %4160, i64 0, !dbg !179 + %4263 = insertelement <2 x float> %4262, float %4161, i64 1, !dbg !179 + %4264 = fmul <2 x float> %4263, splat (float 0x3FB6A09E60000000), !dbg !179 + %4265 = fptrunc <2 x float> %4264 to <2 x bfloat>, !dbg !180 + %4266 = insertelement <2 x float> poison, float %4162, i64 0, !dbg !179 + %4267 = insertelement <2 x float> %4266, float %4163, i64 1, !dbg !179 + %4268 = fmul <2 x float> %4267, splat (float 0x3FB6A09E60000000), !dbg !179 + %4269 = fptrunc <2 x float> %4268 to <2 x bfloat>, !dbg !180 + %4270 = insertelement <2 x float> poison, float %4164, i64 0, !dbg !179 + %4271 = insertelement <2 x float> %4270, float %4165, i64 1, !dbg !179 + %4272 = fmul <2 x float> %4271, splat (float 0x3FB6A09E60000000), !dbg !179 + %4273 = fptrunc <2 x float> %4272 to <2 x bfloat>, !dbg !180 + %4274 = 
insertelement <2 x float> poison, float %4166, i64 0, !dbg !179 + %4275 = insertelement <2 x float> %4274, float %4167, i64 1, !dbg !179 + %4276 = fmul <2 x float> %4275, splat (float 0x3FB6A09E60000000), !dbg !179 + %4277 = fptrunc <2 x float> %4276 to <2 x bfloat>, !dbg !180 + %4278 = insertelement <2 x float> poison, float %4168, i64 0, !dbg !179 + %4279 = insertelement <2 x float> %4278, float %4169, i64 1, !dbg !179 + %4280 = fmul <2 x float> %4279, splat (float 0x3FB6A09E60000000), !dbg !179 + %4281 = fptrunc <2 x float> %4280 to <2 x bfloat>, !dbg !180 + %4282 = insertelement <2 x float> poison, float %4170, i64 0, !dbg !179 + %4283 = insertelement <2 x float> %4282, float %4171, i64 1, !dbg !179 + %4284 = fmul <2 x float> %4283, splat (float 0x3FB6A09E60000000), !dbg !179 + %4285 = fptrunc <2 x float> %4284 to <2 x bfloat>, !dbg !180 + %4286 = insertelement <2 x float> poison, float %4172, i64 0, !dbg !179 + %4287 = insertelement <2 x float> %4286, float %4173, i64 1, !dbg !179 + %4288 = fmul <2 x float> %4287, splat (float 0x3FB6A09E60000000), !dbg !179 + %4289 = fptrunc <2 x float> %4288 to <2 x bfloat>, !dbg !180 + %4290 = insertelement <2 x float> poison, float %4174, i64 0, !dbg !179 + %4291 = insertelement <2 x float> %4290, float %4175, i64 1, !dbg !179 + %4292 = fmul <2 x float> %4291, splat (float 0x3FB6A09E60000000), !dbg !179 + %4293 = fptrunc <2 x float> %4292 to <2 x bfloat>, !dbg !180 + %4294 = insertelement <2 x float> poison, float %4176, i64 0, !dbg !179 + %4295 = insertelement <2 x float> %4294, float %4177, i64 1, !dbg !179 + %4296 = fmul <2 x float> %4295, splat (float 0x3FB6A09E60000000), !dbg !179 + %4297 = fptrunc <2 x float> %4296 to <2 x bfloat>, !dbg !180 + %4298 = insertelement <2 x float> poison, float %4178, i64 0, !dbg !179 + %4299 = insertelement <2 x float> %4298, float %4179, i64 1, !dbg !179 + %4300 = fmul <2 x float> %4299, splat (float 0x3FB6A09E60000000), !dbg !179 + %4301 = fptrunc <2 x float> %4300 to <2 x bfloat>, 
!dbg !180 + %4302 = insertelement <2 x float> poison, float %4180, i64 0, !dbg !179 + %4303 = insertelement <2 x float> %4302, float %4181, i64 1, !dbg !179 + %4304 = fmul <2 x float> %4303, splat (float 0x3FB6A09E60000000), !dbg !179 + %4305 = fptrunc <2 x float> %4304 to <2 x bfloat>, !dbg !180 + %4306 = insertelement <2 x float> poison, float %4182, i64 0, !dbg !179 + %4307 = insertelement <2 x float> %4306, float %4183, i64 1, !dbg !179 + %4308 = fmul <2 x float> %4307, splat (float 0x3FB6A09E60000000), !dbg !179 + %4309 = fptrunc <2 x float> %4308 to <2 x bfloat>, !dbg !180 + %4310 = insertelement <2 x float> poison, float %4184, i64 0, !dbg !179 + %4311 = insertelement <2 x float> %4310, float %4185, i64 1, !dbg !179 + %4312 = fmul <2 x float> %4311, splat (float 0x3FB6A09E60000000), !dbg !179 + %4313 = fptrunc <2 x float> %4312 to <2 x bfloat>, !dbg !180 + %4314 = shl nuw nsw i32 %365, 13, !dbg !180 + %4315 = shl nuw nsw i32 %35, 5, !dbg !180 + %4316 = and i32 %4315, 7264, !dbg !180 + %4317 = and i32 %35, 24, !dbg !180 + %4318 = shl nuw nsw i32 %4317, 4, !dbg !180 + %4319 = shl nuw nsw i32 %35, 2, !dbg !180 + %4320 = and i32 %4319, 16, !dbg !180 + %4321 = or disjoint i32 %4314, %4320, !dbg !180 + %4322 = or disjoint i32 %4316, %4318, !dbg !180 + %4323 = or disjoint i32 %4321, %4322, !dbg !180 + %4324 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4323, !dbg !180 + %4325 = bitcast <2 x bfloat> %4189 to i32, !dbg !180 + %4326 = bitcast <2 x bfloat> %4197 to i32, !dbg !180 + %4327 = bitcast <2 x bfloat> %4205 to i32, !dbg !180 + %4328 = bitcast <2 x bfloat> %4213 to i32, !dbg !180 + %4329 = insertelement <4 x i32> poison, i32 %4325, i64 0, !dbg !180 + %4330 = insertelement <4 x i32> %4329, i32 %4326, i64 1, !dbg !180 + %4331 = insertelement <4 x i32> %4330, i32 %4327, i64 2, !dbg !180 + %4332 = insertelement <4 x i32> %4331, i32 %4328, i64 3, !dbg !180 + store <4 x i32> %4332, ptr addrspace(3) %4324, align 16, !dbg !180 + %4333 = 
getelementptr inbounds nuw i8, ptr addrspace(3) %4324, i32 512, !dbg !180 + %4334 = bitcast <2 x bfloat> %4193 to i32, !dbg !180 + %4335 = bitcast <2 x bfloat> %4201 to i32, !dbg !180 + %4336 = bitcast <2 x bfloat> %4209 to i32, !dbg !180 + %4337 = bitcast <2 x bfloat> %4217 to i32, !dbg !180 + %4338 = insertelement <4 x i32> poison, i32 %4334, i64 0, !dbg !180 + %4339 = insertelement <4 x i32> %4338, i32 %4335, i64 1, !dbg !180 + %4340 = insertelement <4 x i32> %4339, i32 %4336, i64 2, !dbg !180 + %4341 = insertelement <4 x i32> %4340, i32 %4337, i64 3, !dbg !180 + store <4 x i32> %4341, ptr addrspace(3) %4333, align 16, !dbg !180 + %4342 = xor i32 %4323, 32, !dbg !180 + %4343 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4342, !dbg !180 + %4344 = bitcast <2 x bfloat> %4221 to i32, !dbg !180 + %4345 = bitcast <2 x bfloat> %4229 to i32, !dbg !180 + %4346 = bitcast <2 x bfloat> %4237 to i32, !dbg !180 + %4347 = bitcast <2 x bfloat> %4245 to i32, !dbg !180 + %4348 = insertelement <4 x i32> poison, i32 %4344, i64 0, !dbg !180 + %4349 = insertelement <4 x i32> %4348, i32 %4345, i64 1, !dbg !180 + %4350 = insertelement <4 x i32> %4349, i32 %4346, i64 2, !dbg !180 + %4351 = insertelement <4 x i32> %4350, i32 %4347, i64 3, !dbg !180 + store <4 x i32> %4351, ptr addrspace(3) %4343, align 16, !dbg !180 + %4352 = getelementptr inbounds nuw i8, ptr addrspace(3) %4343, i32 512, !dbg !180 + %4353 = bitcast <2 x bfloat> %4225 to i32, !dbg !180 + %4354 = bitcast <2 x bfloat> %4233 to i32, !dbg !180 + %4355 = bitcast <2 x bfloat> %4241 to i32, !dbg !180 + %4356 = bitcast <2 x bfloat> %4249 to i32, !dbg !180 + %4357 = insertelement <4 x i32> poison, i32 %4353, i64 0, !dbg !180 + %4358 = insertelement <4 x i32> %4357, i32 %4354, i64 1, !dbg !180 + %4359 = insertelement <4 x i32> %4358, i32 %4355, i64 2, !dbg !180 + %4360 = insertelement <4 x i32> %4359, i32 %4356, i64 3, !dbg !180 + store <4 x i32> %4360, ptr addrspace(3) %4352, align 16, !dbg !180 + %4361 = 
xor i32 %4323, 64, !dbg !180 + %4362 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4361, !dbg !180 + %4363 = bitcast <2 x bfloat> %4253 to i32, !dbg !180 + %4364 = bitcast <2 x bfloat> %4261 to i32, !dbg !180 + %4365 = bitcast <2 x bfloat> %4269 to i32, !dbg !180 + %4366 = bitcast <2 x bfloat> %4277 to i32, !dbg !180 + %4367 = insertelement <4 x i32> poison, i32 %4363, i64 0, !dbg !180 + %4368 = insertelement <4 x i32> %4367, i32 %4364, i64 1, !dbg !180 + %4369 = insertelement <4 x i32> %4368, i32 %4365, i64 2, !dbg !180 + %4370 = insertelement <4 x i32> %4369, i32 %4366, i64 3, !dbg !180 + store <4 x i32> %4370, ptr addrspace(3) %4362, align 16, !dbg !180 + %4371 = getelementptr inbounds nuw i8, ptr addrspace(3) %4362, i32 512, !dbg !180 + %4372 = bitcast <2 x bfloat> %4257 to i32, !dbg !180 + %4373 = bitcast <2 x bfloat> %4265 to i32, !dbg !180 + %4374 = bitcast <2 x bfloat> %4273 to i32, !dbg !180 + %4375 = bitcast <2 x bfloat> %4281 to i32, !dbg !180 + %4376 = insertelement <4 x i32> poison, i32 %4372, i64 0, !dbg !180 + %4377 = insertelement <4 x i32> %4376, i32 %4373, i64 1, !dbg !180 + %4378 = insertelement <4 x i32> %4377, i32 %4374, i64 2, !dbg !180 + %4379 = insertelement <4 x i32> %4378, i32 %4375, i64 3, !dbg !180 + store <4 x i32> %4379, ptr addrspace(3) %4371, align 16, !dbg !180 + %4380 = xor i32 %4323, 96, !dbg !180 + %4381 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4380, !dbg !180 + %4382 = bitcast <2 x bfloat> %4285 to i32, !dbg !180 + %4383 = bitcast <2 x bfloat> %4293 to i32, !dbg !180 + %4384 = bitcast <2 x bfloat> %4301 to i32, !dbg !180 + %4385 = bitcast <2 x bfloat> %4309 to i32, !dbg !180 + %4386 = insertelement <4 x i32> poison, i32 %4382, i64 0, !dbg !180 + %4387 = insertelement <4 x i32> %4386, i32 %4383, i64 1, !dbg !180 + %4388 = insertelement <4 x i32> %4387, i32 %4384, i64 2, !dbg !180 + %4389 = insertelement <4 x i32> %4388, i32 %4385, i64 3, !dbg !180 + store <4 x i32> %4389, ptr 
addrspace(3) %4381, align 16, !dbg !180 + %4390 = getelementptr inbounds nuw i8, ptr addrspace(3) %4381, i32 512, !dbg !180 + %4391 = bitcast <2 x bfloat> %4289 to i32, !dbg !180 + %4392 = bitcast <2 x bfloat> %4297 to i32, !dbg !180 + %4393 = bitcast <2 x bfloat> %4305 to i32, !dbg !180 + %4394 = bitcast <2 x bfloat> %4313 to i32, !dbg !180 + %4395 = insertelement <4 x i32> poison, i32 %4391, i64 0, !dbg !180 + %4396 = insertelement <4 x i32> %4395, i32 %4392, i64 1, !dbg !180 + %4397 = insertelement <4 x i32> %4396, i32 %4393, i64 2, !dbg !180 + %4398 = insertelement <4 x i32> %4397, i32 %4394, i64 3, !dbg !180 + store <4 x i32> %4398, ptr addrspace(3) %4390, align 16, !dbg !180 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !180 + %4399 = shl nuw nsw i32 %4317, 10, !dbg !180 + %4400 = shl nuw nsw i32 %365, 5, !dbg !180 + %4401 = and i32 %4319, 1008, !dbg !180 + %4402 = or disjoint i32 %4399, %4400, !dbg !180 + %4403 = xor i32 %4402, %4401, !dbg !180 + %4404 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4403, !dbg !180 + %4405 = ptrtoint ptr addrspace(3) %4404 to i32, !dbg !180 + %4406 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4405) #3, !dbg !180 + %4407 = extractvalue { i32, i32, i32, i32 } %4406, 0, !dbg !180 + %4408 = extractvalue { i32, i32, i32, i32 } %4406, 1, !dbg !180 + %4409 = extractvalue { i32, i32, i32, i32 } %4406, 2, !dbg !180 + %4410 = extractvalue { i32, i32, i32, i32 } %4406, 3, !dbg !180 + %4411 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 1024, !dbg !180 + %4412 = ptrtoint ptr addrspace(3) %4411 to i32, !dbg !180 + %4413 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4412) #3, !dbg !180 + %4414 = extractvalue { i32, i32, i32, i32 } %4413, 0, !dbg !180 + %4415 = extractvalue { i32, i32, i32, i32 } 
%4413, 1, !dbg !180 + %4416 = extractvalue { i32, i32, i32, i32 } %4413, 2, !dbg !180 + %4417 = extractvalue { i32, i32, i32, i32 } %4413, 3, !dbg !180 + %4418 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 2048, !dbg !180 + %4419 = ptrtoint ptr addrspace(3) %4418 to i32, !dbg !180 + %4420 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4419) #3, !dbg !180 + %4421 = extractvalue { i32, i32, i32, i32 } %4420, 0, !dbg !180 + %4422 = extractvalue { i32, i32, i32, i32 } %4420, 1, !dbg !180 + %4423 = extractvalue { i32, i32, i32, i32 } %4420, 2, !dbg !180 + %4424 = extractvalue { i32, i32, i32, i32 } %4420, 3, !dbg !180 + %4425 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 3072, !dbg !180 + %4426 = ptrtoint ptr addrspace(3) %4425 to i32, !dbg !180 + %4427 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4426) #3, !dbg !180 + %4428 = extractvalue { i32, i32, i32, i32 } %4427, 0, !dbg !180 + %4429 = extractvalue { i32, i32, i32, i32 } %4427, 1, !dbg !180 + %4430 = extractvalue { i32, i32, i32, i32 } %4427, 2, !dbg !180 + %4431 = extractvalue { i32, i32, i32, i32 } %4427, 3, !dbg !180 + %4432 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 4096, !dbg !180 + %4433 = ptrtoint ptr addrspace(3) %4432 to i32, !dbg !180 + %4434 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4433) #3, !dbg !180 + %4435 = extractvalue { i32, i32, i32, i32 } %4434, 0, !dbg !180 + %4436 = extractvalue { i32, i32, i32, i32 } %4434, 1, !dbg !180 + %4437 = extractvalue { i32, i32, i32, i32 } %4434, 2, !dbg !180 + %4438 = extractvalue { i32, i32, i32, i32 } %4434, 3, !dbg !180 + %4439 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 5120, !dbg !180 + %4440 = ptrtoint ptr 
addrspace(3) %4439 to i32, !dbg !180 + %4441 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4440) #3, !dbg !180 + %4442 = extractvalue { i32, i32, i32, i32 } %4441, 0, !dbg !180 + %4443 = extractvalue { i32, i32, i32, i32 } %4441, 1, !dbg !180 + %4444 = extractvalue { i32, i32, i32, i32 } %4441, 2, !dbg !180 + %4445 = extractvalue { i32, i32, i32, i32 } %4441, 3, !dbg !180 + %4446 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 6144, !dbg !180 + %4447 = ptrtoint ptr addrspace(3) %4446 to i32, !dbg !180 + %4448 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4447) #3, !dbg !180 + %4449 = extractvalue { i32, i32, i32, i32 } %4448, 0, !dbg !180 + %4450 = extractvalue { i32, i32, i32, i32 } %4448, 1, !dbg !180 + %4451 = extractvalue { i32, i32, i32, i32 } %4448, 2, !dbg !180 + %4452 = extractvalue { i32, i32, i32, i32 } %4448, 3, !dbg !180 + %4453 = getelementptr inbounds nuw i8, ptr addrspace(3) %4404, i32 7168, !dbg !180 + %4454 = ptrtoint ptr addrspace(3) %4453 to i32, !dbg !180 + %4455 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4454) #3, !dbg !180 + %4456 = extractvalue { i32, i32, i32, i32 } %4455, 0, !dbg !180 + %4457 = extractvalue { i32, i32, i32, i32 } %4455, 1, !dbg !180 + %4458 = extractvalue { i32, i32, i32, i32 } %4455, 2, !dbg !180 + %4459 = extractvalue { i32, i32, i32, i32 } %4455, 3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4407, i32 %4408, i32 %4409, i32 %4410, ptr addrspace(1) %4114) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4414, i32 %4415, i32 %4416, i32 %4417, ptr addrspace(1) %4115) #3, !dbg !180 + tail call void asm 
sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4421, i32 %4422, i32 %4423, i32 %4424, ptr addrspace(1) %4116) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4428, i32 %4429, i32 %4430, i32 %4431, ptr addrspace(1) %4117) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4435, i32 %4436, i32 %4437, i32 %4438, ptr addrspace(1) %4118) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4442, i32 %4443, i32 %4444, i32 %4445, ptr addrspace(1) %4119) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4449, i32 %4450, i32 %4451, i32 %4452, ptr addrspace(1) %4120) #3, !dbg !180 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %4456, i32 %4457, i32 %4458, i32 %4459, ptr addrspace(1) %4121) #3, !dbg !180 + br label %10331, !dbg !24 + +4460: ; preds = %20 + %4461 = shl nuw nsw i32 %21, 7, !dbg !181 + %4462 = or disjoint i32 %38, %4461, !dbg !182 + %4463 = or disjoint i32 %39, %4461, !dbg !182 + %4464 = or disjoint i32 %40, %4461, !dbg !182 + %4465 = or disjoint i32 %41, %4461, !dbg !182 + %4466 = or disjoint i32 %42, %4461, !dbg !182 + %4467 = or disjoint i32 %43, %4461, !dbg !182 + %4468 = or disjoint i32 %44, %4461, !dbg !182 + %4469 = or disjoint i32 %45, %4461, !dbg !182 + %4470 = or disjoint i32 %50, %4461, !dbg !182 + %4471 = or disjoint i32 %51, %4461, !dbg !182 + %4472 = shl nuw nsw i32 %4462, 7, !dbg !183 + %4473 = shl nuw nsw i32 %4463, 7, !dbg !183 + %4474 = shl nuw nsw i32 %4464, 7, !dbg !183 + %4475 = shl nuw nsw i32 %4465, 7, !dbg !183 + %4476 = shl nuw nsw i32 %4466, 7, !dbg !183 + %4477 = shl nuw nsw i32 %4467, 7, !dbg !183 + %4478 = shl nuw nsw i32 %4468, 7, !dbg !183 + %4479 = shl nuw nsw i32 %4469, 7, !dbg !183 + %4480 = 
zext nneg i32 %4472 to i64, !dbg !185 + %4481 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4480, !dbg !185 + %4482 = zext nneg i32 %4473 to i64, !dbg !185 + %4483 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4482, !dbg !185 + %4484 = zext nneg i32 %4474 to i64, !dbg !185 + %4485 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4484, !dbg !185 + %4486 = zext nneg i32 %4475 to i64, !dbg !185 + %4487 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4486, !dbg !185 + %4488 = zext nneg i32 %4476 to i64, !dbg !185 + %4489 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4488, !dbg !185 + %4490 = zext nneg i32 %4477 to i64, !dbg !185 + %4491 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4490, !dbg !185 + %4492 = zext nneg i32 %4478 to i64, !dbg !185 + %4493 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4492, !dbg !185 + %4494 = zext nneg i32 %4479 to i64, !dbg !185 + %4495 = getelementptr bfloat, ptr addrspace(1) %32, i64 %4494, !dbg !185 + %4496 = shl nuw nsw i32 %35, 3, !dbg !186 + %4497 = and i32 %4496, 120, !dbg !186 + %4498 = zext nneg i32 %4497 to i64, !dbg !187 + %4499 = getelementptr bfloat, ptr addrspace(1) %4481, i64 %4498, !dbg !187 + %4500 = getelementptr bfloat, ptr addrspace(1) %4483, i64 %4498, !dbg !187 + %4501 = getelementptr bfloat, ptr addrspace(1) %4485, i64 %4498, !dbg !187 + %4502 = getelementptr bfloat, ptr addrspace(1) %4487, i64 %4498, !dbg !187 + %4503 = getelementptr bfloat, ptr addrspace(1) %4489, i64 %4498, !dbg !187 + %4504 = getelementptr bfloat, ptr addrspace(1) %4491, i64 %4498, !dbg !187 + %4505 = getelementptr bfloat, ptr addrspace(1) %4493, i64 %4498, !dbg !187 + %4506 = getelementptr bfloat, ptr addrspace(1) %4495, i64 %4498, !dbg !187 + %4507 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4499) #3, !dbg !188 + %4508 = extractvalue { 
i32, i32, i32, i32 } %4507, 0, !dbg !188 + %4509 = extractvalue { i32, i32, i32, i32 } %4507, 1, !dbg !188 + %4510 = extractvalue { i32, i32, i32, i32 } %4507, 2, !dbg !188 + %4511 = extractvalue { i32, i32, i32, i32 } %4507, 3, !dbg !188 + %4512 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4500) #3, !dbg !188 + %4513 = extractvalue { i32, i32, i32, i32 } %4512, 0, !dbg !188 + %4514 = extractvalue { i32, i32, i32, i32 } %4512, 1, !dbg !188 + %4515 = extractvalue { i32, i32, i32, i32 } %4512, 2, !dbg !188 + %4516 = extractvalue { i32, i32, i32, i32 } %4512, 3, !dbg !188 + %4517 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4501) #3, !dbg !188 + %4518 = extractvalue { i32, i32, i32, i32 } %4517, 0, !dbg !188 + %4519 = extractvalue { i32, i32, i32, i32 } %4517, 1, !dbg !188 + %4520 = extractvalue { i32, i32, i32, i32 } %4517, 2, !dbg !188 + %4521 = extractvalue { i32, i32, i32, i32 } %4517, 3, !dbg !188 + %4522 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4502) #3, !dbg !188 + %4523 = extractvalue { i32, i32, i32, i32 } %4522, 0, !dbg !188 + %4524 = extractvalue { i32, i32, i32, i32 } %4522, 1, !dbg !188 + %4525 = extractvalue { i32, i32, i32, i32 } %4522, 2, !dbg !188 + %4526 = extractvalue { i32, i32, i32, i32 } %4522, 3, !dbg !188 + %4527 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", 
"=r,=r,=r,=r,l"(ptr addrspace(1) %4503) #3, !dbg !188 + %4528 = extractvalue { i32, i32, i32, i32 } %4527, 0, !dbg !188 + %4529 = extractvalue { i32, i32, i32, i32 } %4527, 1, !dbg !188 + %4530 = extractvalue { i32, i32, i32, i32 } %4527, 2, !dbg !188 + %4531 = extractvalue { i32, i32, i32, i32 } %4527, 3, !dbg !188 + %4532 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4504) #3, !dbg !188 + %4533 = extractvalue { i32, i32, i32, i32 } %4532, 0, !dbg !188 + %4534 = extractvalue { i32, i32, i32, i32 } %4532, 1, !dbg !188 + %4535 = extractvalue { i32, i32, i32, i32 } %4532, 2, !dbg !188 + %4536 = extractvalue { i32, i32, i32, i32 } %4532, 3, !dbg !188 + %4537 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4505) #3, !dbg !188 + %4538 = extractvalue { i32, i32, i32, i32 } %4537, 0, !dbg !188 + %4539 = extractvalue { i32, i32, i32, i32 } %4537, 1, !dbg !188 + %4540 = extractvalue { i32, i32, i32, i32 } %4537, 2, !dbg !188 + %4541 = extractvalue { i32, i32, i32, i32 } %4537, 3, !dbg !188 + %4542 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4506) #3, !dbg !188 + %4543 = extractvalue { i32, i32, i32, i32 } %4542, 0, !dbg !188 + %4544 = extractvalue { i32, i32, i32, i32 } %4542, 1, !dbg !188 + %4545 = extractvalue { i32, i32, i32, i32 } %4542, 2, !dbg !188 + %4546 = extractvalue { i32, i32, i32, i32 } %4542, 3, !dbg !188 + %4547 = shl nuw nsw i32 %35, 4, !dbg !188 + %4548 = and i32 %4547, 112, !dbg !188 + %4549 = shl nuw nsw i32 %37, 3, !dbg !188 + 
%4550 = and i32 %35, 112, !dbg !188 + %4551 = and i32 %35, 8, !dbg !188 + %4552 = shl nuw nsw i32 %4551, 11, !dbg !188 + %4553 = or disjoint i32 %4548, %4549, !dbg !188 + %4554 = xor i32 %4553, %4550, !dbg !188 + %4555 = or disjoint i32 %4554, %4552, !dbg !188 + %4556 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4555, !dbg !188 + %4557 = insertelement <4 x i32> poison, i32 %4508, i64 0, !dbg !188 + %4558 = insertelement <4 x i32> %4557, i32 %4509, i64 1, !dbg !188 + %4559 = insertelement <4 x i32> %4558, i32 %4510, i64 2, !dbg !188 + %4560 = insertelement <4 x i32> %4559, i32 %4511, i64 3, !dbg !188 + store <4 x i32> %4560, ptr addrspace(3) %4556, align 16, !dbg !188 + %4561 = or disjoint i32 %4555, 2048, !dbg !188 + %4562 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4561, !dbg !188 + %4563 = insertelement <4 x i32> poison, i32 %4513, i64 0, !dbg !188 + %4564 = insertelement <4 x i32> %4563, i32 %4514, i64 1, !dbg !188 + %4565 = insertelement <4 x i32> %4564, i32 %4515, i64 2, !dbg !188 + %4566 = insertelement <4 x i32> %4565, i32 %4516, i64 3, !dbg !188 + store <4 x i32> %4566, ptr addrspace(3) %4562, align 16, !dbg !188 + %4567 = or disjoint i32 %4555, 4096, !dbg !188 + %4568 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4567, !dbg !188 + %4569 = insertelement <4 x i32> poison, i32 %4518, i64 0, !dbg !188 + %4570 = insertelement <4 x i32> %4569, i32 %4519, i64 1, !dbg !188 + %4571 = insertelement <4 x i32> %4570, i32 %4520, i64 2, !dbg !188 + %4572 = insertelement <4 x i32> %4571, i32 %4521, i64 3, !dbg !188 + store <4 x i32> %4572, ptr addrspace(3) %4568, align 16, !dbg !188 + %4573 = or disjoint i32 %4555, 6144, !dbg !188 + %4574 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4573, 
!dbg !188 + %4575 = insertelement <4 x i32> poison, i32 %4523, i64 0, !dbg !188 + %4576 = insertelement <4 x i32> %4575, i32 %4524, i64 1, !dbg !188 + %4577 = insertelement <4 x i32> %4576, i32 %4525, i64 2, !dbg !188 + %4578 = insertelement <4 x i32> %4577, i32 %4526, i64 3, !dbg !188 + store <4 x i32> %4578, ptr addrspace(3) %4574, align 16, !dbg !188 + %4579 = or disjoint i32 %4555, 8192, !dbg !188 + %4580 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4579, !dbg !188 + %4581 = insertelement <4 x i32> poison, i32 %4528, i64 0, !dbg !188 + %4582 = insertelement <4 x i32> %4581, i32 %4529, i64 1, !dbg !188 + %4583 = insertelement <4 x i32> %4582, i32 %4530, i64 2, !dbg !188 + %4584 = insertelement <4 x i32> %4583, i32 %4531, i64 3, !dbg !188 + store <4 x i32> %4584, ptr addrspace(3) %4580, align 16, !dbg !188 + %4585 = or disjoint i32 %4555, 10240, !dbg !188 + %4586 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4585, !dbg !188 + %4587 = insertelement <4 x i32> poison, i32 %4533, i64 0, !dbg !188 + %4588 = insertelement <4 x i32> %4587, i32 %4534, i64 1, !dbg !188 + %4589 = insertelement <4 x i32> %4588, i32 %4535, i64 2, !dbg !188 + %4590 = insertelement <4 x i32> %4589, i32 %4536, i64 3, !dbg !188 + store <4 x i32> %4590, ptr addrspace(3) %4586, align 16, !dbg !188 + %4591 = or disjoint i32 %4555, 12288, !dbg !188 + %4592 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4591, !dbg !188 + %4593 = insertelement <4 x i32> poison, i32 %4538, i64 0, !dbg !188 + %4594 = insertelement <4 x i32> %4593, i32 %4539, i64 1, !dbg !188 + %4595 = insertelement <4 x i32> %4594, i32 %4540, i64 2, !dbg !188 + %4596 = insertelement <4 x i32> %4595, i32 %4541, i64 3, !dbg !188 + store <4 x i32> %4596, ptr addrspace(3) %4592, align 16, !dbg !188 + %4597 = or disjoint i32 %4555, 
14336, !dbg !188 + %4598 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4597, !dbg !188 + %4599 = insertelement <4 x i32> poison, i32 %4543, i64 0, !dbg !188 + %4600 = insertelement <4 x i32> %4599, i32 %4544, i64 1, !dbg !188 + %4601 = insertelement <4 x i32> %4600, i32 %4545, i64 2, !dbg !188 + %4602 = insertelement <4 x i32> %4601, i32 %4546, i64 3, !dbg !188 + store <4 x i32> %4602, ptr addrspace(3) %4598, align 16, !dbg !188 + %4603 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4480, !dbg !189 + %4604 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4482, !dbg !189 + %4605 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4484, !dbg !189 + %4606 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4486, !dbg !189 + %4607 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4488, !dbg !189 + %4608 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4490, !dbg !189 + %4609 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4492, !dbg !189 + %4610 = getelementptr bfloat, ptr addrspace(1) %33, i64 %4494, !dbg !189 + %4611 = getelementptr bfloat, ptr addrspace(1) %4603, i64 %4498, !dbg !191 + %4612 = getelementptr bfloat, ptr addrspace(1) %4604, i64 %4498, !dbg !191 + %4613 = getelementptr bfloat, ptr addrspace(1) %4605, i64 %4498, !dbg !191 + %4614 = getelementptr bfloat, ptr addrspace(1) %4606, i64 %4498, !dbg !191 + %4615 = getelementptr bfloat, ptr addrspace(1) %4607, i64 %4498, !dbg !191 + %4616 = getelementptr bfloat, ptr addrspace(1) %4608, i64 %4498, !dbg !191 + %4617 = getelementptr bfloat, ptr addrspace(1) %4609, i64 %4498, !dbg !191 + %4618 = getelementptr bfloat, ptr addrspace(1) %4610, i64 %4498, !dbg !191 + %4619 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4611) #3, !dbg !192 + %4620 = extractvalue { 
i32, i32, i32, i32 } %4619, 0, !dbg !192 + %4621 = extractvalue { i32, i32, i32, i32 } %4619, 1, !dbg !192 + %4622 = extractvalue { i32, i32, i32, i32 } %4619, 2, !dbg !192 + %4623 = extractvalue { i32, i32, i32, i32 } %4619, 3, !dbg !192 + %4624 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4612) #3, !dbg !192 + %4625 = extractvalue { i32, i32, i32, i32 } %4624, 0, !dbg !192 + %4626 = extractvalue { i32, i32, i32, i32 } %4624, 1, !dbg !192 + %4627 = extractvalue { i32, i32, i32, i32 } %4624, 2, !dbg !192 + %4628 = extractvalue { i32, i32, i32, i32 } %4624, 3, !dbg !192 + %4629 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4613) #3, !dbg !192 + %4630 = extractvalue { i32, i32, i32, i32 } %4629, 0, !dbg !192 + %4631 = extractvalue { i32, i32, i32, i32 } %4629, 1, !dbg !192 + %4632 = extractvalue { i32, i32, i32, i32 } %4629, 2, !dbg !192 + %4633 = extractvalue { i32, i32, i32, i32 } %4629, 3, !dbg !192 + %4634 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4614) #3, !dbg !192 + %4635 = extractvalue { i32, i32, i32, i32 } %4634, 0, !dbg !192 + %4636 = extractvalue { i32, i32, i32, i32 } %4634, 1, !dbg !192 + %4637 = extractvalue { i32, i32, i32, i32 } %4634, 2, !dbg !192 + %4638 = extractvalue { i32, i32, i32, i32 } %4634, 3, !dbg !192 + %4639 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", 
"=r,=r,=r,=r,l"(ptr addrspace(1) %4615) #3, !dbg !192 + %4640 = extractvalue { i32, i32, i32, i32 } %4639, 0, !dbg !192 + %4641 = extractvalue { i32, i32, i32, i32 } %4639, 1, !dbg !192 + %4642 = extractvalue { i32, i32, i32, i32 } %4639, 2, !dbg !192 + %4643 = extractvalue { i32, i32, i32, i32 } %4639, 3, !dbg !192 + %4644 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4616) #3, !dbg !192 + %4645 = extractvalue { i32, i32, i32, i32 } %4644, 0, !dbg !192 + %4646 = extractvalue { i32, i32, i32, i32 } %4644, 1, !dbg !192 + %4647 = extractvalue { i32, i32, i32, i32 } %4644, 2, !dbg !192 + %4648 = extractvalue { i32, i32, i32, i32 } %4644, 3, !dbg !192 + %4649 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4617) #3, !dbg !192 + %4650 = extractvalue { i32, i32, i32, i32 } %4649, 0, !dbg !192 + %4651 = extractvalue { i32, i32, i32, i32 } %4649, 1, !dbg !192 + %4652 = extractvalue { i32, i32, i32, i32 } %4649, 2, !dbg !192 + %4653 = extractvalue { i32, i32, i32, i32 } %4649, 3, !dbg !192 + %4654 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l"(ptr addrspace(1) %4618) #3, !dbg !192 + %4655 = extractvalue { i32, i32, i32, i32 } %4654, 0, !dbg !192 + %4656 = extractvalue { i32, i32, i32, i32 } %4654, 1, !dbg !192 + %4657 = extractvalue { i32, i32, i32, i32 } %4654, 2, !dbg !192 + %4658 = extractvalue { i32, i32, i32, i32 } %4654, 3, !dbg !192 + %4659 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4555, 
!dbg !192 + %4660 = insertelement <4 x i32> poison, i32 %4620, i64 0, !dbg !192 + %4661 = insertelement <4 x i32> %4660, i32 %4621, i64 1, !dbg !192 + %4662 = insertelement <4 x i32> %4661, i32 %4622, i64 2, !dbg !192 + %4663 = insertelement <4 x i32> %4662, i32 %4623, i64 3, !dbg !192 + store <4 x i32> %4663, ptr addrspace(3) %4659, align 16, !dbg !192 + %4664 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4561, !dbg !192 + %4665 = insertelement <4 x i32> poison, i32 %4625, i64 0, !dbg !192 + %4666 = insertelement <4 x i32> %4665, i32 %4626, i64 1, !dbg !192 + %4667 = insertelement <4 x i32> %4666, i32 %4627, i64 2, !dbg !192 + %4668 = insertelement <4 x i32> %4667, i32 %4628, i64 3, !dbg !192 + store <4 x i32> %4668, ptr addrspace(3) %4664, align 16, !dbg !192 + %4669 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4567, !dbg !192 + %4670 = insertelement <4 x i32> poison, i32 %4630, i64 0, !dbg !192 + %4671 = insertelement <4 x i32> %4670, i32 %4631, i64 1, !dbg !192 + %4672 = insertelement <4 x i32> %4671, i32 %4632, i64 2, !dbg !192 + %4673 = insertelement <4 x i32> %4672, i32 %4633, i64 3, !dbg !192 + store <4 x i32> %4673, ptr addrspace(3) %4669, align 16, !dbg !192 + %4674 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4573, !dbg !192 + %4675 = insertelement <4 x i32> poison, i32 %4635, i64 0, !dbg !192 + %4676 = insertelement <4 x i32> %4675, i32 %4636, i64 1, !dbg !192 + %4677 = insertelement <4 x i32> %4676, i32 %4637, i64 2, !dbg !192 + %4678 = insertelement <4 x i32> %4677, i32 %4638, i64 3, !dbg !192 + store <4 x i32> %4678, ptr addrspace(3) %4674, align 16, !dbg !192 + %4679 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4579, !dbg !192 + %4680 = insertelement <4 x i32> 
poison, i32 %4640, i64 0, !dbg !192 + %4681 = insertelement <4 x i32> %4680, i32 %4641, i64 1, !dbg !192 + %4682 = insertelement <4 x i32> %4681, i32 %4642, i64 2, !dbg !192 + %4683 = insertelement <4 x i32> %4682, i32 %4643, i64 3, !dbg !192 + store <4 x i32> %4683, ptr addrspace(3) %4679, align 16, !dbg !192 + %4684 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4585, !dbg !192 + %4685 = insertelement <4 x i32> poison, i32 %4645, i64 0, !dbg !192 + %4686 = insertelement <4 x i32> %4685, i32 %4646, i64 1, !dbg !192 + %4687 = insertelement <4 x i32> %4686, i32 %4647, i64 2, !dbg !192 + %4688 = insertelement <4 x i32> %4687, i32 %4648, i64 3, !dbg !192 + store <4 x i32> %4688, ptr addrspace(3) %4684, align 16, !dbg !192 + %4689 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4591, !dbg !192 + %4690 = insertelement <4 x i32> poison, i32 %4650, i64 0, !dbg !192 + %4691 = insertelement <4 x i32> %4690, i32 %4651, i64 1, !dbg !192 + %4692 = insertelement <4 x i32> %4691, i32 %4652, i64 2, !dbg !192 + %4693 = insertelement <4 x i32> %4692, i32 %4653, i64 3, !dbg !192 + store <4 x i32> %4693, ptr addrspace(3) %4689, align 16, !dbg !192 + %4694 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4597, !dbg !192 + %4695 = insertelement <4 x i32> poison, i32 %4655, i64 0, !dbg !192 + %4696 = insertelement <4 x i32> %4695, i32 %4656, i64 1, !dbg !192 + %4697 = insertelement <4 x i32> %4696, i32 %4657, i64 2, !dbg !192 + %4698 = insertelement <4 x i32> %4697, i32 %4658, i64 3, !dbg !192 + store <4 x i32> %4698, ptr addrspace(3) %4694, align 16, !dbg !192 + %4699 = shl nuw nsw i32 %23, 2, !dbg !193 + %4700 = shl i32 %22, 23, !dbg !194 + %4701 = shl nuw nsw i32 %24, 4, !dbg !195 + %4702 = or disjoint i32 %4701, %21, !dbg !196 + %4703 = shl nuw nsw i32 %24, 8, !dbg !197 + 
%4704 = shl nuw nsw i32 %21, 4, !dbg !198 + %4705 = or disjoint i32 %4703, %4704, !dbg !199 + %4706 = zext nneg i32 %4705 to i64, !dbg !200 + %4707 = getelementptr i32, ptr addrspace(1) %11, i64 %4706, !dbg !200 + %4708 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4707, i1 true) #3, !dbg !201 + %4709 = shl i32 %4708, 7, !dbg !202 + %4710 = zext nneg i32 %4702 to i64, !dbg !203 + %4711 = getelementptr i32, ptr addrspace(1) %10, i64 %4710, !dbg !203 + %4712 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4711, i1 true) #3, !dbg !204 + %4713 = and i32 %35, 3, !dbg !205 + %4714 = shl nuw nsw i32 %4713, 1, !dbg !205 + %4715 = or disjoint i32 %4714, 8, !dbg !205 + %4716 = or disjoint i32 %4714, 16, !dbg !205 + %4717 = or disjoint i32 %4714, 24, !dbg !205 + %4718 = or disjoint i32 %4714, 32, !dbg !205 + %4719 = or disjoint i32 %4714, 40, !dbg !205 + %4720 = or disjoint i32 %4714, 48, !dbg !205 + %4721 = or disjoint i32 %4714, 56, !dbg !205 + %4722 = or disjoint i32 %4709, %38, !dbg !206 + %4723 = or disjoint i32 %4709, %39, !dbg !206 + %4724 = or disjoint i32 %4709, %40, !dbg !206 + %4725 = or disjoint i32 %4709, %41, !dbg !206 + %4726 = or disjoint i32 %4709, %4714, !dbg !206 + %4727 = or disjoint i32 %4726, 1, !dbg !206 + %4728 = or disjoint i32 %4709, %4715, !dbg !206 + %4729 = or disjoint i32 %4726, 9, !dbg !206 + %4730 = or disjoint i32 %4709, %4716, !dbg !206 + %4731 = or disjoint i32 %4726, 17, !dbg !206 + %4732 = or disjoint i32 %4709, %4717, !dbg !206 + %4733 = or disjoint i32 %4726, 25, !dbg !206 + %4734 = or disjoint i32 %4709, %4718, !dbg !206 + %4735 = or disjoint i32 %4726, 33, !dbg !206 + %4736 = or disjoint i32 %4709, %4719, !dbg !206 + %4737 = or disjoint i32 %4726, 41, !dbg !206 + %4738 = or disjoint i32 %4709, %4720, !dbg !206 + %4739 = or disjoint i32 %4726, 49, !dbg !206 + %4740 = or disjoint i32 %4709, 
%4721, !dbg !206 + %4741 = or disjoint i32 %4726, 57, !dbg !206 + %4742 = shl i32 %4722, 12, !dbg !207 + %4743 = shl i32 %4723, 12, !dbg !207 + %4744 = shl i32 %4724, 12, !dbg !207 + %4745 = shl i32 %4725, 12, !dbg !207 + %4746 = shl i32 %4722, 7, !dbg !209 + %4747 = shl i32 %4723, 7, !dbg !209 + %4748 = shl i32 %4724, 7, !dbg !209 + %4749 = shl i32 %4725, 7, !dbg !209 + %4750 = shl i32 %4712, 1, !dbg !210 + %4751 = tail call i32 @llvm.smin.i32(i32 %4750, i32 32), !dbg !211 + %4752 = zext nneg i32 %22 to i64, !dbg !212 + %4753 = getelementptr i64, ptr addrspace(1) %16, i64 %4752, !dbg !212 + %4754 = icmp sgt i32 %4750, 0, !dbg !213 + %4755 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %4753, i1 %4754) #3, !dbg !214 + %4756 = getelementptr i32, ptr addrspace(1) %15, i64 %4706, !dbg !215 + %4757 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4756, i1 true) #3, !dbg !216 + %4758 = shl i32 %4757, 7, !dbg !217 + %4759 = getelementptr i32, ptr addrspace(1) %14, i64 %4710, !dbg !218 + %4760 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4759, i1 true) #3, !dbg !219 + %4761 = or disjoint i32 %4758, %38, !dbg !220 + %4762 = or disjoint i32 %4758, %39, !dbg !220 + %4763 = or disjoint i32 %4758, %40, !dbg !220 + %4764 = or disjoint i32 %4758, %41, !dbg !220 + %4765 = or disjoint i32 %4758, %4714, !dbg !220 + %4766 = or disjoint i32 %4758, %4715, !dbg !220 + %4767 = or disjoint i32 %4758, %4716, !dbg !220 + %4768 = or disjoint i32 %4758, %4717, !dbg !220 + %4769 = or disjoint i32 %4758, %4718, !dbg !220 + %4770 = or disjoint i32 %4758, %4719, !dbg !220 + %4771 = or disjoint i32 %4758, %4720, !dbg !220 + %4772 = or disjoint i32 %4758, %4721, !dbg !220 + %4773 = shl i32 %4761, 12, !dbg !221 + %4774 = shl i32 %4762, 12, !dbg !221 + %4775 = shl i32 %4763, 12, 
!dbg !221 + %4776 = shl i32 %4764, 12, !dbg !221 + %4777 = shl i32 %4761, 7, !dbg !223 + %4778 = shl i32 %4762, 7, !dbg !223 + %4779 = shl i32 %4763, 7, !dbg !223 + %4780 = shl i32 %4764, 7, !dbg !223 + %4781 = shl i32 %4760, 1, !dbg !224 + %4782 = tail call i32 @llvm.smin.i32(i32 %4781, i32 32), !dbg !225 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !226 + %4783 = shl nuw i32 %22, 16 + %4784 = sext i32 %4742 to i64 + %4785 = sext i32 %4743 to i64 + %4786 = sext i32 %4744 to i64 + %4787 = sext i32 %4745 to i64 + %4788 = sext i32 %4746 to i64 + %4789 = sext i32 %4747 to i64 + %4790 = sext i32 %4748 to i64 + %4791 = sext i32 %4749 to i64 + %4792 = shl nuw nsw i32 %4551, 10 + %4793 = or disjoint i32 %4554, %4792 + %4794 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4793 + %4795 = select i1 %4754, i32 16, i32 0 + %4796 = or disjoint i32 %4793, 2048 + %4797 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4796 + %4798 = or disjoint i32 %4793, 4096 + %4799 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4798 + %4800 = or disjoint i32 %4793, 6144 + %4801 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4800 + %4802 = sext i32 %4726 to i64 + %4803 = sext i32 %4728 to i64 + %4804 = sext i32 %4730 to i64 + %4805 = sext i32 %4732 to i64 + %4806 = sext i32 %4734 to i64 + %4807 = sext i32 %4736 to i64 + %4808 = sext i32 %4738 to i64 + %4809 = sext i32 %4740 to i64 + %4810 = and i32 %35, 252 + %4811 = icmp eq i32 %4810, 0 + %4812 = shl nuw nsw i32 %4713, 3 + %4813 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4812 + %4814 = select i1 %4754, i32 8, i32 0 + %4815 = or disjoint i32 %4812, 32 + %4816 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4815 + %4817 = or disjoint i32 %4812, 64 + %4818 = getelementptr inbounds nuw i8, 
ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4817 + %4819 = or disjoint i32 %4812, 96 + %4820 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4819 + %4821 = or disjoint i32 %4812, 128 + %4822 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4821 + %4823 = or disjoint i32 %4812, 160 + %4824 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4823 + %4825 = or disjoint i32 %4812, 192 + %4826 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4825 + %4827 = or disjoint i32 %4812, 224 + %4828 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %4827 + %4829 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4793 + %4830 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4796 + %4831 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4798 + %4832 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4800 + %4833 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4812 + %4834 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4815 + %4835 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4817 + %4836 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4819 + %4837 = getelementptr inbounds nuw 
i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4821 + %4838 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4823 + %4839 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4825 + %4840 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %4827 + %4841 = icmp sgt i32 %4750, 1 + %4842 = or disjoint i32 %4726, 64 + %4843 = or disjoint i32 %4728, 64 + %4844 = or disjoint i32 %4730, 64 + %4845 = or disjoint i32 %4732, 64 + %4846 = or disjoint i32 %4734, 64 + %4847 = or disjoint i32 %4736, 64 + %4848 = or disjoint i32 %4738, 64 + %4849 = or disjoint i32 %4740, 64 + %4850 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %4793 + %4851 = select i1 %4841, i32 16, i32 0 + %4852 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %4796 + %4853 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %4798 + %4854 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %4800 + %4855 = sext i32 %4842 to i64 + %4856 = sext i32 %4843 to i64 + %4857 = sext i32 %4844 to i64 + %4858 = sext i32 %4845 to i64 + %4859 = sext i32 %4846 to i64 + %4860 = sext i32 %4847 to i64 + %4861 = sext i32 %4848 to i64 + %4862 = sext i32 %4849 to i64 + %4863 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4812 + %4864 = select i1 %4841, i32 8, i32 0 + %4865 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4815 + %4866 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, 
ptr addrspace(3) @global_smem, i32 98560), i32 %4817 + %4867 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4819 + %4868 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4821 + %4869 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4823 + %4870 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4825 + %4871 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %4827 + %4872 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %4793 + %4873 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %4796 + %4874 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %4798 + %4875 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %4800 + %4876 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4812 + %4877 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4815 + %4878 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4817 + %4879 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4819 + %4880 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4821 + %4881 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 
%4823 + %4882 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4825 + %4883 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %4827 + %4884 = add i32 %4751, -2 + %4885 = add nsw i32 %4751, -1 + %4886 = sext i32 %4773 to i64 + %4887 = sext i32 %4774 to i64 + %4888 = sext i32 %4775 to i64 + %4889 = sext i32 %4776 to i64 + %4890 = sext i32 %4777 to i64 + %4891 = sext i32 %4778 to i64 + %4892 = sext i32 %4779 to i64 + %4893 = sext i32 %4780 to i64 + %4894 = icmp sgt i32 %4781, 0 + %4895 = select i1 %4894, i32 16, i32 0 + %4896 = sext i32 %4765 to i64 + %4897 = sext i32 %4766 to i64 + %4898 = sext i32 %4767 to i64 + %4899 = sext i32 %4768 to i64 + %4900 = sext i32 %4769 to i64 + %4901 = sext i32 %4770 to i64 + %4902 = sext i32 %4771 to i64 + %4903 = sext i32 %4772 to i64 + %4904 = select i1 %4894, i32 8, i32 0 + %4905 = icmp sgt i32 %4781, 1 + %4906 = or disjoint i32 %4765, 64 + %4907 = or disjoint i32 %4766, 64 + %4908 = or disjoint i32 %4767, 64 + %4909 = or disjoint i32 %4768, 64 + %4910 = or disjoint i32 %4769, 64 + %4911 = or disjoint i32 %4770, 64 + %4912 = or disjoint i32 %4771, 64 + %4913 = or disjoint i32 %4772, 64 + %4914 = select i1 %4905, i32 16, i32 0 + %4915 = sext i32 %4906 to i64 + %4916 = sext i32 %4907 to i64 + %4917 = sext i32 %4908 to i64 + %4918 = sext i32 %4909 to i64 + %4919 = sext i32 %4910 to i64 + %4920 = sext i32 %4911 to i64 + %4921 = sext i32 %4912 to i64 + %4922 = sext i32 %4913 to i64 + %4923 = select i1 %4905, i32 8, i32 0 + %4924 = add i32 %4782, -2 + %4925 = add nsw i32 %4782, -1 + %smax = tail call i32 @llvm.smax.i32(i32 %4751, i32 1), !dbg !227 + %smax2121 = tail call i32 @llvm.smax.i32(i32 %4782, i32 1), !dbg !227 + %4926 = zext nneg i32 %4699 to i64, !dbg !227 + %4927 = insertelement <16 x i32> poison, i32 %4741, i64 0 + %4928 = insertelement <16 x i32> %4927, i32 %4740, i64 1 + %4929 = insertelement 
<16 x i32> %4928, i32 %4739, i64 2 + %4930 = insertelement <16 x i32> %4929, i32 %4738, i64 3 + %4931 = insertelement <16 x i32> %4930, i32 %4737, i64 4 + %4932 = insertelement <16 x i32> %4931, i32 %4736, i64 5 + %4933 = insertelement <16 x i32> %4932, i32 %4735, i64 6 + %4934 = insertelement <16 x i32> %4933, i32 %4734, i64 7 + %4935 = insertelement <16 x i32> %4934, i32 %4733, i64 8 + %4936 = insertelement <16 x i32> %4935, i32 %4732, i64 9 + %4937 = insertelement <16 x i32> %4936, i32 %4731, i64 10 + %4938 = insertelement <16 x i32> %4937, i32 %4730, i64 11 + %4939 = insertelement <16 x i32> %4938, i32 %4729, i64 12 + %4940 = insertelement <16 x i32> %4939, i32 %4728, i64 13 + %4941 = insertelement <16 x i32> %4940, i32 %4727, i64 14 + %4942 = insertelement <16 x i32> %4941, i32 %4726, i64 15 + %4943 = insertelement <16 x i64> poison, i64 %4755, i64 0 + %4944 = shufflevector <16 x i64> %4943, <16 x i64> poison, <16 x i32> zeroinitializer + br label %4945, !dbg !227 + +4945: ; preds = %4460, %._crit_edge1692 + %indvars.iv = phi i64 [ 0, %4460 ], [ %indvars.iv.next, %._crit_edge1692 ] + %4946 = phi float [ 0.000000e+00, %4460 ], [ %9744, %._crit_edge1692 ] + %4947 = phi float [ 0.000000e+00, %4460 ], [ %9745, %._crit_edge1692 ] + %4948 = phi float [ 0.000000e+00, %4460 ], [ %9746, %._crit_edge1692 ] + %4949 = phi float [ 0.000000e+00, %4460 ], [ %9747, %._crit_edge1692 ] + %4950 = phi float [ 0.000000e+00, %4460 ], [ %9748, %._crit_edge1692 ] + %4951 = phi float [ 0.000000e+00, %4460 ], [ %9749, %._crit_edge1692 ] + %4952 = phi float [ 0.000000e+00, %4460 ], [ %9750, %._crit_edge1692 ] + %4953 = phi float [ 0.000000e+00, %4460 ], [ %9751, %._crit_edge1692 ] + %4954 = phi float [ 0.000000e+00, %4460 ], [ %9752, %._crit_edge1692 ] + %4955 = phi float [ 0.000000e+00, %4460 ], [ %9753, %._crit_edge1692 ] + %4956 = phi float [ 0.000000e+00, %4460 ], [ %9754, %._crit_edge1692 ] + %4957 = phi float [ 0.000000e+00, %4460 ], [ %9755, %._crit_edge1692 ] + %4958 = phi float 
[ 0.000000e+00, %4460 ], [ %9756, %._crit_edge1692 ] + %4959 = phi float [ 0.000000e+00, %4460 ], [ %9757, %._crit_edge1692 ] + %4960 = phi float [ 0.000000e+00, %4460 ], [ %9758, %._crit_edge1692 ] + %4961 = phi float [ 0.000000e+00, %4460 ], [ %9759, %._crit_edge1692 ] + %4962 = phi float [ 0.000000e+00, %4460 ], [ %9760, %._crit_edge1692 ] + %4963 = phi float [ 0.000000e+00, %4460 ], [ %9761, %._crit_edge1692 ] + %4964 = phi float [ 0.000000e+00, %4460 ], [ %9762, %._crit_edge1692 ] + %4965 = phi float [ 0.000000e+00, %4460 ], [ %9763, %._crit_edge1692 ] + %4966 = phi float [ 0.000000e+00, %4460 ], [ %9764, %._crit_edge1692 ] + %4967 = phi float [ 0.000000e+00, %4460 ], [ %9765, %._crit_edge1692 ] + %4968 = phi float [ 0.000000e+00, %4460 ], [ %9766, %._crit_edge1692 ] + %4969 = phi float [ 0.000000e+00, %4460 ], [ %9767, %._crit_edge1692 ] + %4970 = phi float [ 0.000000e+00, %4460 ], [ %9768, %._crit_edge1692 ] + %4971 = phi float [ 0.000000e+00, %4460 ], [ %9769, %._crit_edge1692 ] + %4972 = phi float [ 0.000000e+00, %4460 ], [ %9770, %._crit_edge1692 ] + %4973 = phi float [ 0.000000e+00, %4460 ], [ %9771, %._crit_edge1692 ] + %4974 = phi float [ 0.000000e+00, %4460 ], [ %9772, %._crit_edge1692 ] + %4975 = phi float [ 0.000000e+00, %4460 ], [ %9773, %._crit_edge1692 ] + %4976 = phi float [ 0.000000e+00, %4460 ], [ %9774, %._crit_edge1692 ] + %4977 = phi float [ 0.000000e+00, %4460 ], [ %9775, %._crit_edge1692 ] + %4978 = phi float [ 0.000000e+00, %4460 ], [ %9776, %._crit_edge1692 ] + %4979 = phi float [ 0.000000e+00, %4460 ], [ %9777, %._crit_edge1692 ] + %4980 = phi float [ 0.000000e+00, %4460 ], [ %9778, %._crit_edge1692 ] + %4981 = phi float [ 0.000000e+00, %4460 ], [ %9779, %._crit_edge1692 ] + %4982 = phi float [ 0.000000e+00, %4460 ], [ %9780, %._crit_edge1692 ] + %4983 = phi float [ 0.000000e+00, %4460 ], [ %9781, %._crit_edge1692 ] + %4984 = phi float [ 0.000000e+00, %4460 ], [ %9782, %._crit_edge1692 ] + %4985 = phi float [ 0.000000e+00, %4460 ], [ 
%9783, %._crit_edge1692 ] + %4986 = phi float [ 0.000000e+00, %4460 ], [ %9784, %._crit_edge1692 ] + %4987 = phi float [ 0.000000e+00, %4460 ], [ %9785, %._crit_edge1692 ] + %4988 = phi float [ 0.000000e+00, %4460 ], [ %9786, %._crit_edge1692 ] + %4989 = phi float [ 0.000000e+00, %4460 ], [ %9787, %._crit_edge1692 ] + %4990 = phi float [ 0.000000e+00, %4460 ], [ %9788, %._crit_edge1692 ] + %4991 = phi float [ 0.000000e+00, %4460 ], [ %9789, %._crit_edge1692 ] + %4992 = phi float [ 0.000000e+00, %4460 ], [ %9790, %._crit_edge1692 ] + %4993 = phi float [ 0.000000e+00, %4460 ], [ %9791, %._crit_edge1692 ] + %4994 = phi float [ 0.000000e+00, %4460 ], [ %9792, %._crit_edge1692 ] + %4995 = phi float [ 0.000000e+00, %4460 ], [ %9793, %._crit_edge1692 ] + %4996 = phi float [ 0.000000e+00, %4460 ], [ %9794, %._crit_edge1692 ] + %4997 = phi float [ 0.000000e+00, %4460 ], [ %9795, %._crit_edge1692 ] + %4998 = phi float [ 0.000000e+00, %4460 ], [ %9796, %._crit_edge1692 ] + %4999 = phi float [ 0.000000e+00, %4460 ], [ %9797, %._crit_edge1692 ] + %5000 = phi float [ 0.000000e+00, %4460 ], [ %9798, %._crit_edge1692 ] + %5001 = phi float [ 0.000000e+00, %4460 ], [ %9799, %._crit_edge1692 ] + %5002 = phi float [ 0.000000e+00, %4460 ], [ %9800, %._crit_edge1692 ] + %5003 = phi float [ 0.000000e+00, %4460 ], [ %9801, %._crit_edge1692 ] + %5004 = phi float [ 0.000000e+00, %4460 ], [ %9802, %._crit_edge1692 ] + %5005 = phi float [ 0.000000e+00, %4460 ], [ %9803, %._crit_edge1692 ] + %5006 = phi float [ 0.000000e+00, %4460 ], [ %9804, %._crit_edge1692 ] + %5007 = phi float [ 0.000000e+00, %4460 ], [ %9805, %._crit_edge1692 ] + %5008 = phi float [ 0.000000e+00, %4460 ], [ %9806, %._crit_edge1692 ] + %5009 = phi float [ 0.000000e+00, %4460 ], [ %9807, %._crit_edge1692 ] + %5010 = phi float [ 0.000000e+00, %4460 ], [ %9680, %._crit_edge1692 ] + %5011 = phi float [ 0.000000e+00, %4460 ], [ %9681, %._crit_edge1692 ] + %5012 = phi float [ 0.000000e+00, %4460 ], [ %9682, %._crit_edge1692 ] + 
%5013 = phi float [ 0.000000e+00, %4460 ], [ %9683, %._crit_edge1692 ] + %5014 = phi float [ 0.000000e+00, %4460 ], [ %9684, %._crit_edge1692 ] + %5015 = phi float [ 0.000000e+00, %4460 ], [ %9685, %._crit_edge1692 ] + %5016 = phi float [ 0.000000e+00, %4460 ], [ %9686, %._crit_edge1692 ] + %5017 = phi float [ 0.000000e+00, %4460 ], [ %9687, %._crit_edge1692 ] + %5018 = phi float [ 0.000000e+00, %4460 ], [ %9688, %._crit_edge1692 ] + %5019 = phi float [ 0.000000e+00, %4460 ], [ %9689, %._crit_edge1692 ] + %5020 = phi float [ 0.000000e+00, %4460 ], [ %9690, %._crit_edge1692 ] + %5021 = phi float [ 0.000000e+00, %4460 ], [ %9691, %._crit_edge1692 ] + %5022 = phi float [ 0.000000e+00, %4460 ], [ %9692, %._crit_edge1692 ] + %5023 = phi float [ 0.000000e+00, %4460 ], [ %9693, %._crit_edge1692 ] + %5024 = phi float [ 0.000000e+00, %4460 ], [ %9694, %._crit_edge1692 ] + %5025 = phi float [ 0.000000e+00, %4460 ], [ %9695, %._crit_edge1692 ] + %5026 = phi float [ 0.000000e+00, %4460 ], [ %9696, %._crit_edge1692 ] + %5027 = phi float [ 0.000000e+00, %4460 ], [ %9697, %._crit_edge1692 ] + %5028 = phi float [ 0.000000e+00, %4460 ], [ %9698, %._crit_edge1692 ] + %5029 = phi float [ 0.000000e+00, %4460 ], [ %9699, %._crit_edge1692 ] + %5030 = phi float [ 0.000000e+00, %4460 ], [ %9700, %._crit_edge1692 ] + %5031 = phi float [ 0.000000e+00, %4460 ], [ %9701, %._crit_edge1692 ] + %5032 = phi float [ 0.000000e+00, %4460 ], [ %9702, %._crit_edge1692 ] + %5033 = phi float [ 0.000000e+00, %4460 ], [ %9703, %._crit_edge1692 ] + %5034 = phi float [ 0.000000e+00, %4460 ], [ %9704, %._crit_edge1692 ] + %5035 = phi float [ 0.000000e+00, %4460 ], [ %9705, %._crit_edge1692 ] + %5036 = phi float [ 0.000000e+00, %4460 ], [ %9706, %._crit_edge1692 ] + %5037 = phi float [ 0.000000e+00, %4460 ], [ %9707, %._crit_edge1692 ] + %5038 = phi float [ 0.000000e+00, %4460 ], [ %9708, %._crit_edge1692 ] + %5039 = phi float [ 0.000000e+00, %4460 ], [ %9709, %._crit_edge1692 ] + %5040 = phi float [ 
0.000000e+00, %4460 ], [ %9710, %._crit_edge1692 ] + %5041 = phi float [ 0.000000e+00, %4460 ], [ %9711, %._crit_edge1692 ] + %5042 = phi float [ 0.000000e+00, %4460 ], [ %9712, %._crit_edge1692 ] + %5043 = phi float [ 0.000000e+00, %4460 ], [ %9713, %._crit_edge1692 ] + %5044 = phi float [ 0.000000e+00, %4460 ], [ %9714, %._crit_edge1692 ] + %5045 = phi float [ 0.000000e+00, %4460 ], [ %9715, %._crit_edge1692 ] + %5046 = phi float [ 0.000000e+00, %4460 ], [ %9716, %._crit_edge1692 ] + %5047 = phi float [ 0.000000e+00, %4460 ], [ %9717, %._crit_edge1692 ] + %5048 = phi float [ 0.000000e+00, %4460 ], [ %9718, %._crit_edge1692 ] + %5049 = phi float [ 0.000000e+00, %4460 ], [ %9719, %._crit_edge1692 ] + %5050 = phi float [ 0.000000e+00, %4460 ], [ %9720, %._crit_edge1692 ] + %5051 = phi float [ 0.000000e+00, %4460 ], [ %9721, %._crit_edge1692 ] + %5052 = phi float [ 0.000000e+00, %4460 ], [ %9722, %._crit_edge1692 ] + %5053 = phi float [ 0.000000e+00, %4460 ], [ %9723, %._crit_edge1692 ] + %5054 = phi float [ 0.000000e+00, %4460 ], [ %9724, %._crit_edge1692 ] + %5055 = phi float [ 0.000000e+00, %4460 ], [ %9725, %._crit_edge1692 ] + %5056 = phi float [ 0.000000e+00, %4460 ], [ %9726, %._crit_edge1692 ] + %5057 = phi float [ 0.000000e+00, %4460 ], [ %9727, %._crit_edge1692 ] + %5058 = phi float [ 0.000000e+00, %4460 ], [ %9728, %._crit_edge1692 ] + %5059 = phi float [ 0.000000e+00, %4460 ], [ %9729, %._crit_edge1692 ] + %5060 = phi float [ 0.000000e+00, %4460 ], [ %9730, %._crit_edge1692 ] + %5061 = phi float [ 0.000000e+00, %4460 ], [ %9731, %._crit_edge1692 ] + %5062 = phi float [ 0.000000e+00, %4460 ], [ %9732, %._crit_edge1692 ] + %5063 = phi float [ 0.000000e+00, %4460 ], [ %9733, %._crit_edge1692 ] + %5064 = phi float [ 0.000000e+00, %4460 ], [ %9734, %._crit_edge1692 ] + %5065 = phi float [ 0.000000e+00, %4460 ], [ %9735, %._crit_edge1692 ] + %5066 = phi float [ 0.000000e+00, %4460 ], [ %9736, %._crit_edge1692 ] + %5067 = phi float [ 0.000000e+00, %4460 ], [ 
%9737, %._crit_edge1692 ] + %5068 = phi float [ 0.000000e+00, %4460 ], [ %9738, %._crit_edge1692 ] + %5069 = phi float [ 0.000000e+00, %4460 ], [ %9739, %._crit_edge1692 ] + %5070 = phi float [ 0.000000e+00, %4460 ], [ %9740, %._crit_edge1692 ] + %5071 = phi float [ 0.000000e+00, %4460 ], [ %9741, %._crit_edge1692 ] + %5072 = phi float [ 0.000000e+00, %4460 ], [ %9742, %._crit_edge1692 ] + %5073 = phi float [ 0.000000e+00, %4460 ], [ %9743, %._crit_edge1692 ] + %5074 = add nuw nsw i64 %indvars.iv, %4926, !dbg !228 + %.tr = trunc i64 %5074 to i32, !dbg !229 + %5075 = shl i32 %.tr, 7, !dbg !229 + %5076 = add i32 %5075, %4700, !dbg !229 + %5077 = sext i32 %5076 to i64, !dbg !230 + %5078 = trunc nuw nsw i64 %5074 to i32, !dbg !231 + %5079 = shl i32 %5078, 18, !dbg !231 + %5080 = add i32 %5079, %4700, !dbg !232 + %5081 = sext i32 %5080 to i64, !dbg !233 + %.tr2128 = trunc i64 %5074 to i32, !dbg !234 + %5082 = shl i32 %.tr2128, 11, !dbg !234 + %5083 = add i32 %5082, %4783, !dbg !234 + %5084 = sext i32 %5083 to i64, !dbg !235 + %5085 = getelementptr bfloat, ptr addrspace(1) %0, i64 %5077, !dbg !236 + %5086 = getelementptr bfloat, ptr addrspace(1) %5, i64 %5081, !dbg !237 + %5087 = getelementptr float, ptr addrspace(1) %3, i64 %5084, !dbg !238 + %5088 = getelementptr float, ptr addrspace(1) %4, i64 %5084, !dbg !239 + %5089 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4784, !dbg !240 + %5090 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4785, !dbg !240 + %5091 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4786, !dbg !240 + %5092 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4787, !dbg !240 + %5093 = getelementptr bfloat, ptr addrspace(1) %5089, i64 %4498, !dbg !241 + %5094 = getelementptr bfloat, ptr addrspace(1) %5090, i64 %4498, !dbg !241 + %5095 = getelementptr bfloat, ptr addrspace(1) %5091, i64 %4498, !dbg !241 + %5096 = getelementptr bfloat, ptr addrspace(1) %5092, i64 %4498, !dbg !241 + %5097 = getelementptr bfloat, ptr addrspace(1) 
%5086, i64 %4788, !dbg !242 + %5098 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4789, !dbg !242 + %5099 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4790, !dbg !242 + %5100 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4791, !dbg !242 + %5101 = getelementptr bfloat, ptr addrspace(1) %5097, i64 %4498, !dbg !243 + %5102 = getelementptr bfloat, ptr addrspace(1) %5098, i64 %4498, !dbg !243 + %5103 = getelementptr bfloat, ptr addrspace(1) %5099, i64 %4498, !dbg !243 + %5104 = getelementptr bfloat, ptr addrspace(1) %5100, i64 %4498, !dbg !243 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4794, ptr addrspace(1) %5093, i32 %4795) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4797, ptr addrspace(1) %5094, i32 %4795) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4799, ptr addrspace(1) %5095, i32 %4795) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4801, ptr addrspace(1) %5096, i32 %4795) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %5105 = getelementptr float, ptr addrspace(1) %5087, i64 %4802, !dbg !245 + %5106 = getelementptr float, ptr addrspace(1) %5087, i64 %4803, !dbg !245 + %5107 = getelementptr float, ptr addrspace(1) %5087, i64 %4804, !dbg !245 + %5108 = getelementptr float, ptr addrspace(1) %5087, i64 %4805, !dbg !245 + %5109 = getelementptr float, ptr addrspace(1) %5087, i64 %4806, !dbg !245 + %5110 = getelementptr float, ptr addrspace(1) %5087, i64 %4807, !dbg !245 + %5111 = getelementptr float, ptr addrspace(1) %5087, i64 %4808, !dbg !245 + %5112 = getelementptr float, ptr addrspace(1) %5087, i64 %4809, !dbg !245 + tail call void asm sideeffect 
"@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4813, ptr addrspace(1) %5105, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4816, ptr addrspace(1) %5106, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4818, ptr addrspace(1) %5107, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4820, ptr addrspace(1) %5108, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4822, ptr addrspace(1) %5109, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4824, ptr addrspace(1) %5110, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4826, ptr addrspace(1) %5111, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4828, ptr addrspace(1) %5112, i32 %4814, i1 %4811) #3, !dbg !246 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !246 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4829, ptr addrspace(1) %5101, i32 %4795) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4830, ptr addrspace(1) %5102, i32 %4795) #3, !dbg !244 + tail call void asm 
sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4831, ptr addrspace(1) %5103, i32 %4795) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4832, ptr addrspace(1) %5104, i32 %4795) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %5113 = getelementptr float, ptr addrspace(1) %5088, i64 %4802, !dbg !247 + %5114 = getelementptr float, ptr addrspace(1) %5088, i64 %4803, !dbg !247 + %5115 = getelementptr float, ptr addrspace(1) %5088, i64 %4804, !dbg !247 + %5116 = getelementptr float, ptr addrspace(1) %5088, i64 %4805, !dbg !247 + %5117 = getelementptr float, ptr addrspace(1) %5088, i64 %4806, !dbg !247 + %5118 = getelementptr float, ptr addrspace(1) %5088, i64 %4807, !dbg !247 + %5119 = getelementptr float, ptr addrspace(1) %5088, i64 %4808, !dbg !247 + %5120 = getelementptr float, ptr addrspace(1) %5088, i64 %4809, !dbg !247 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4833, ptr addrspace(1) %5113, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4834, ptr addrspace(1) %5114, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4835, ptr addrspace(1) %5115, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4836, ptr addrspace(1) %5116, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4837, ptr addrspace(1) %5117, i32 %4814, i1 %4811) #3, !dbg 
!248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4838, ptr addrspace(1) %5118, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4839, ptr addrspace(1) %5119, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4840, ptr addrspace(1) %5120, i32 %4814, i1 %4811) #3, !dbg !248 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !248 + %5121 = getelementptr i8, ptr addrspace(1) %5093, i64 524288, !dbg !249 + %5122 = getelementptr i8, ptr addrspace(1) %5094, i64 524288, !dbg !249 + %5123 = getelementptr i8, ptr addrspace(1) %5095, i64 524288, !dbg !249 + %5124 = getelementptr i8, ptr addrspace(1) %5096, i64 524288, !dbg !249 + %5125 = getelementptr i8, ptr addrspace(1) %5101, i64 16384, !dbg !250 + %5126 = getelementptr i8, ptr addrspace(1) %5102, i64 16384, !dbg !250 + %5127 = getelementptr i8, ptr addrspace(1) %5103, i64 16384, !dbg !250 + %5128 = getelementptr i8, ptr addrspace(1) %5104, i64 16384, !dbg !250 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4850, ptr addrspace(1) %5121, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4852, ptr addrspace(1) %5122, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4853, ptr addrspace(1) %5123, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull 
%4854, ptr addrspace(1) %5124, i32 %4851) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %5129 = getelementptr float, ptr addrspace(1) %5087, i64 %4855, !dbg !245 + %5130 = getelementptr float, ptr addrspace(1) %5087, i64 %4856, !dbg !245 + %5131 = getelementptr float, ptr addrspace(1) %5087, i64 %4857, !dbg !245 + %5132 = getelementptr float, ptr addrspace(1) %5087, i64 %4858, !dbg !245 + %5133 = getelementptr float, ptr addrspace(1) %5087, i64 %4859, !dbg !245 + %5134 = getelementptr float, ptr addrspace(1) %5087, i64 %4860, !dbg !245 + %5135 = getelementptr float, ptr addrspace(1) %5087, i64 %4861, !dbg !245 + %5136 = getelementptr float, ptr addrspace(1) %5087, i64 %4862, !dbg !245 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4863, ptr addrspace(1) %5129, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4865, ptr addrspace(1) %5130, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4866, ptr addrspace(1) %5131, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4867, ptr addrspace(1) %5132, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4868, ptr addrspace(1) %5133, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4869, ptr addrspace(1) %5134, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 
0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4870, ptr addrspace(1) %5135, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4871, ptr addrspace(1) %5136, i32 %4864, i1 %4811) #3, !dbg !246 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !246 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4872, ptr addrspace(1) %5125, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4873, ptr addrspace(1) %5126, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4874, ptr addrspace(1) %5127, i32 %4851) #3, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4875, ptr addrspace(1) %5128, i32 %4851) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %5137 = getelementptr float, ptr addrspace(1) %5088, i64 %4855, !dbg !247 + %5138 = getelementptr float, ptr addrspace(1) %5088, i64 %4856, !dbg !247 + %5139 = getelementptr float, ptr addrspace(1) %5088, i64 %4857, !dbg !247 + %5140 = getelementptr float, ptr addrspace(1) %5088, i64 %4858, !dbg !247 + %5141 = getelementptr float, ptr addrspace(1) %5088, i64 %4859, !dbg !247 + %5142 = getelementptr float, ptr addrspace(1) %5088, i64 %4860, !dbg !247 + %5143 = getelementptr float, ptr addrspace(1) %5088, i64 %4861, !dbg !247 + %5144 = getelementptr float, ptr addrspace(1) %5088, i64 %4862, !dbg !247 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4876, ptr addrspace(1) %5137, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect 
"@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4877, ptr addrspace(1) %5138, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4878, ptr addrspace(1) %5139, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4879, ptr addrspace(1) %5140, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4880, ptr addrspace(1) %5141, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4881, ptr addrspace(1) %5142, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4882, ptr addrspace(1) %5143, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4883, ptr addrspace(1) %5144, i32 %4864, i1 %4811) #3, !dbg !248 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !248 + br i1 %4754, label %.lr.ph, label %._crit_edge, !dbg !213 + +.lr.ph: ; preds = %4945, %__nv_exp2f.exit1315 + %5145 = phi i32 [ %7352, %__nv_exp2f.exit1315 ], [ 64, %4945 ] + %5146 = phi i32 [ %5284, %__nv_exp2f.exit1315 ], [ -1, %4945 ] + %5147 = phi i32 [ %7375, %__nv_exp2f.exit1315 ], [ 1, %4945 ] + %5148 = phi i32 [ %5287, %__nv_exp2f.exit1315 ], [ -1, %4945 ] + %5149 = phi i32 [ %7378, %__nv_exp2f.exit1315 ], [ 1, %4945 ] + %.pn1881528 = phi i32 [ %7372, %__nv_exp2f.exit1315 ], [ %4849, %4945 ] + %.pn1921527 = phi i32 [ %7371, %__nv_exp2f.exit1315 ], [ %4848, 
%4945 ] + %.pn1961526 = phi i32 [ %7370, %__nv_exp2f.exit1315 ], [ %4847, %4945 ] + %.pn2001525 = phi i32 [ %7369, %__nv_exp2f.exit1315 ], [ %4846, %4945 ] + %.pn2041524 = phi i32 [ %7368, %__nv_exp2f.exit1315 ], [ %4845, %4945 ] + %.pn2081523 = phi i32 [ %7367, %__nv_exp2f.exit1315 ], [ %4844, %4945 ] + %.pn2121522 = phi i32 [ %7366, %__nv_exp2f.exit1315 ], [ %4843, %4945 ] + %.pn2161521 = phi i32 [ %7365, %__nv_exp2f.exit1315 ], [ %4842, %4945 ] + %.pn1361520 = phi ptr addrspace(1) [ %7364, %__nv_exp2f.exit1315 ], [ %5128, %4945 ] + %.pn1521519 = phi ptr addrspace(1) [ %7363, %__nv_exp2f.exit1315 ], [ %5127, %4945 ] + %.pn1681518 = phi ptr addrspace(1) [ %7362, %__nv_exp2f.exit1315 ], [ %5126, %4945 ] + %.pn1841517 = phi ptr addrspace(1) [ %7361, %__nv_exp2f.exit1315 ], [ %5125, %4945 ] + %.pn721516 = phi ptr addrspace(1) [ %7358, %__nv_exp2f.exit1315 ], [ %5124, %4945 ] + %.pn881515 = phi ptr addrspace(1) [ %7357, %__nv_exp2f.exit1315 ], [ %5123, %4945 ] + %.pn1041514 = phi ptr addrspace(1) [ %7356, %__nv_exp2f.exit1315 ], [ %5122, %4945 ] + %.pn1201513 = phi ptr addrspace(1) [ %7355, %__nv_exp2f.exit1315 ], [ %5121, %4945 ] + %5150 = phi float [ %6410, %__nv_exp2f.exit1315 ], [ %5010, %4945 ] + %5151 = phi float [ %6411, %__nv_exp2f.exit1315 ], [ %5011, %4945 ] + %5152 = phi float [ %6412, %__nv_exp2f.exit1315 ], [ %5012, %4945 ] + %5153 = phi float [ %6413, %__nv_exp2f.exit1315 ], [ %5013, %4945 ] + %5154 = phi float [ %6414, %__nv_exp2f.exit1315 ], [ %5014, %4945 ] + %5155 = phi float [ %6415, %__nv_exp2f.exit1315 ], [ %5015, %4945 ] + %5156 = phi float [ %6416, %__nv_exp2f.exit1315 ], [ %5016, %4945 ] + %5157 = phi float [ %6417, %__nv_exp2f.exit1315 ], [ %5017, %4945 ] + %5158 = phi float [ %6418, %__nv_exp2f.exit1315 ], [ %5018, %4945 ] + %5159 = phi float [ %6419, %__nv_exp2f.exit1315 ], [ %5019, %4945 ] + %5160 = phi float [ %6420, %__nv_exp2f.exit1315 ], [ %5020, %4945 ] + %5161 = phi float [ %6421, %__nv_exp2f.exit1315 ], [ %5021, %4945 ] + %5162 = phi 
float [ %6422, %__nv_exp2f.exit1315 ], [ %5022, %4945 ] + %5163 = phi float [ %6423, %__nv_exp2f.exit1315 ], [ %5023, %4945 ] + %5164 = phi float [ %6424, %__nv_exp2f.exit1315 ], [ %5024, %4945 ] + %5165 = phi float [ %6425, %__nv_exp2f.exit1315 ], [ %5025, %4945 ] + %5166 = phi float [ %6426, %__nv_exp2f.exit1315 ], [ %5026, %4945 ] + %5167 = phi float [ %6427, %__nv_exp2f.exit1315 ], [ %5027, %4945 ] + %5168 = phi float [ %6428, %__nv_exp2f.exit1315 ], [ %5028, %4945 ] + %5169 = phi float [ %6429, %__nv_exp2f.exit1315 ], [ %5029, %4945 ] + %5170 = phi float [ %6430, %__nv_exp2f.exit1315 ], [ %5030, %4945 ] + %5171 = phi float [ %6431, %__nv_exp2f.exit1315 ], [ %5031, %4945 ] + %5172 = phi float [ %6432, %__nv_exp2f.exit1315 ], [ %5032, %4945 ] + %5173 = phi float [ %6433, %__nv_exp2f.exit1315 ], [ %5033, %4945 ] + %5174 = phi float [ %6434, %__nv_exp2f.exit1315 ], [ %5034, %4945 ] + %5175 = phi float [ %6435, %__nv_exp2f.exit1315 ], [ %5035, %4945 ] + %5176 = phi float [ %6436, %__nv_exp2f.exit1315 ], [ %5036, %4945 ] + %5177 = phi float [ %6437, %__nv_exp2f.exit1315 ], [ %5037, %4945 ] + %5178 = phi float [ %6438, %__nv_exp2f.exit1315 ], [ %5038, %4945 ] + %5179 = phi float [ %6439, %__nv_exp2f.exit1315 ], [ %5039, %4945 ] + %5180 = phi float [ %6440, %__nv_exp2f.exit1315 ], [ %5040, %4945 ] + %5181 = phi float [ %6441, %__nv_exp2f.exit1315 ], [ %5041, %4945 ] + %5182 = phi float [ %6442, %__nv_exp2f.exit1315 ], [ %5042, %4945 ] + %5183 = phi float [ %6443, %__nv_exp2f.exit1315 ], [ %5043, %4945 ] + %5184 = phi float [ %6444, %__nv_exp2f.exit1315 ], [ %5044, %4945 ] + %5185 = phi float [ %6445, %__nv_exp2f.exit1315 ], [ %5045, %4945 ] + %5186 = phi float [ %6446, %__nv_exp2f.exit1315 ], [ %5046, %4945 ] + %5187 = phi float [ %6447, %__nv_exp2f.exit1315 ], [ %5047, %4945 ] + %5188 = phi float [ %6448, %__nv_exp2f.exit1315 ], [ %5048, %4945 ] + %5189 = phi float [ %6449, %__nv_exp2f.exit1315 ], [ %5049, %4945 ] + %5190 = phi float [ %6450, %__nv_exp2f.exit1315 ], 
[ %5050, %4945 ] + %5191 = phi float [ %6451, %__nv_exp2f.exit1315 ], [ %5051, %4945 ] + %5192 = phi float [ %6452, %__nv_exp2f.exit1315 ], [ %5052, %4945 ] + %5193 = phi float [ %6453, %__nv_exp2f.exit1315 ], [ %5053, %4945 ] + %5194 = phi float [ %6454, %__nv_exp2f.exit1315 ], [ %5054, %4945 ] + %5195 = phi float [ %6455, %__nv_exp2f.exit1315 ], [ %5055, %4945 ] + %5196 = phi float [ %6456, %__nv_exp2f.exit1315 ], [ %5056, %4945 ] + %5197 = phi float [ %6457, %__nv_exp2f.exit1315 ], [ %5057, %4945 ] + %5198 = phi float [ %6458, %__nv_exp2f.exit1315 ], [ %5058, %4945 ] + %5199 = phi float [ %6459, %__nv_exp2f.exit1315 ], [ %5059, %4945 ] + %5200 = phi float [ %6460, %__nv_exp2f.exit1315 ], [ %5060, %4945 ] + %5201 = phi float [ %6461, %__nv_exp2f.exit1315 ], [ %5061, %4945 ] + %5202 = phi float [ %6462, %__nv_exp2f.exit1315 ], [ %5062, %4945 ] + %5203 = phi float [ %6463, %__nv_exp2f.exit1315 ], [ %5063, %4945 ] + %5204 = phi float [ %6464, %__nv_exp2f.exit1315 ], [ %5064, %4945 ] + %5205 = phi float [ %6465, %__nv_exp2f.exit1315 ], [ %5065, %4945 ] + %5206 = phi float [ %6466, %__nv_exp2f.exit1315 ], [ %5066, %4945 ] + %5207 = phi float [ %6467, %__nv_exp2f.exit1315 ], [ %5067, %4945 ] + %5208 = phi float [ %6468, %__nv_exp2f.exit1315 ], [ %5068, %4945 ] + %5209 = phi float [ %6469, %__nv_exp2f.exit1315 ], [ %5069, %4945 ] + %5210 = phi float [ %6470, %__nv_exp2f.exit1315 ], [ %5070, %4945 ] + %5211 = phi float [ %6471, %__nv_exp2f.exit1315 ], [ %5071, %4945 ] + %5212 = phi float [ %6472, %__nv_exp2f.exit1315 ], [ %5072, %4945 ] + %5213 = phi float [ %6473, %__nv_exp2f.exit1315 ], [ %5073, %4945 ] + %5214 = phi float [ %7266, %__nv_exp2f.exit1315 ], [ %4946, %4945 ] + %5215 = phi float [ %7267, %__nv_exp2f.exit1315 ], [ %4947, %4945 ] + %5216 = phi float [ %7268, %__nv_exp2f.exit1315 ], [ %4948, %4945 ] + %5217 = phi float [ %7269, %__nv_exp2f.exit1315 ], [ %4949, %4945 ] + %5218 = phi float [ %7270, %__nv_exp2f.exit1315 ], [ %4950, %4945 ] + %5219 = phi float [ 
%7271, %__nv_exp2f.exit1315 ], [ %4951, %4945 ] + %5220 = phi float [ %7272, %__nv_exp2f.exit1315 ], [ %4952, %4945 ] + %5221 = phi float [ %7273, %__nv_exp2f.exit1315 ], [ %4953, %4945 ] + %5222 = phi float [ %7274, %__nv_exp2f.exit1315 ], [ %4954, %4945 ] + %5223 = phi float [ %7275, %__nv_exp2f.exit1315 ], [ %4955, %4945 ] + %5224 = phi float [ %7276, %__nv_exp2f.exit1315 ], [ %4956, %4945 ] + %5225 = phi float [ %7277, %__nv_exp2f.exit1315 ], [ %4957, %4945 ] + %5226 = phi float [ %7278, %__nv_exp2f.exit1315 ], [ %4958, %4945 ] + %5227 = phi float [ %7279, %__nv_exp2f.exit1315 ], [ %4959, %4945 ] + %5228 = phi float [ %7280, %__nv_exp2f.exit1315 ], [ %4960, %4945 ] + %5229 = phi float [ %7281, %__nv_exp2f.exit1315 ], [ %4961, %4945 ] + %5230 = phi float [ %7282, %__nv_exp2f.exit1315 ], [ %4962, %4945 ] + %5231 = phi float [ %7283, %__nv_exp2f.exit1315 ], [ %4963, %4945 ] + %5232 = phi float [ %7284, %__nv_exp2f.exit1315 ], [ %4964, %4945 ] + %5233 = phi float [ %7285, %__nv_exp2f.exit1315 ], [ %4965, %4945 ] + %5234 = phi float [ %7286, %__nv_exp2f.exit1315 ], [ %4966, %4945 ] + %5235 = phi float [ %7287, %__nv_exp2f.exit1315 ], [ %4967, %4945 ] + %5236 = phi float [ %7288, %__nv_exp2f.exit1315 ], [ %4968, %4945 ] + %5237 = phi float [ %7289, %__nv_exp2f.exit1315 ], [ %4969, %4945 ] + %5238 = phi float [ %7290, %__nv_exp2f.exit1315 ], [ %4970, %4945 ] + %5239 = phi float [ %7291, %__nv_exp2f.exit1315 ], [ %4971, %4945 ] + %5240 = phi float [ %7292, %__nv_exp2f.exit1315 ], [ %4972, %4945 ] + %5241 = phi float [ %7293, %__nv_exp2f.exit1315 ], [ %4973, %4945 ] + %5242 = phi float [ %7294, %__nv_exp2f.exit1315 ], [ %4974, %4945 ] + %5243 = phi float [ %7295, %__nv_exp2f.exit1315 ], [ %4975, %4945 ] + %5244 = phi float [ %7296, %__nv_exp2f.exit1315 ], [ %4976, %4945 ] + %5245 = phi float [ %7297, %__nv_exp2f.exit1315 ], [ %4977, %4945 ] + %5246 = phi float [ %7298, %__nv_exp2f.exit1315 ], [ %4978, %4945 ] + %5247 = phi float [ %7299, %__nv_exp2f.exit1315 ], [ %4979, 
%4945 ] + %5248 = phi float [ %7300, %__nv_exp2f.exit1315 ], [ %4980, %4945 ] + %5249 = phi float [ %7301, %__nv_exp2f.exit1315 ], [ %4981, %4945 ] + %5250 = phi float [ %7302, %__nv_exp2f.exit1315 ], [ %4982, %4945 ] + %5251 = phi float [ %7303, %__nv_exp2f.exit1315 ], [ %4983, %4945 ] + %5252 = phi float [ %7304, %__nv_exp2f.exit1315 ], [ %4984, %4945 ] + %5253 = phi float [ %7305, %__nv_exp2f.exit1315 ], [ %4985, %4945 ] + %5254 = phi float [ %7306, %__nv_exp2f.exit1315 ], [ %4986, %4945 ] + %5255 = phi float [ %7307, %__nv_exp2f.exit1315 ], [ %4987, %4945 ] + %5256 = phi float [ %7308, %__nv_exp2f.exit1315 ], [ %4988, %4945 ] + %5257 = phi float [ %7309, %__nv_exp2f.exit1315 ], [ %4989, %4945 ] + %5258 = phi float [ %7310, %__nv_exp2f.exit1315 ], [ %4990, %4945 ] + %5259 = phi float [ %7311, %__nv_exp2f.exit1315 ], [ %4991, %4945 ] + %5260 = phi float [ %7312, %__nv_exp2f.exit1315 ], [ %4992, %4945 ] + %5261 = phi float [ %7313, %__nv_exp2f.exit1315 ], [ %4993, %4945 ] + %5262 = phi float [ %7314, %__nv_exp2f.exit1315 ], [ %4994, %4945 ] + %5263 = phi float [ %7315, %__nv_exp2f.exit1315 ], [ %4995, %4945 ] + %5264 = phi float [ %7316, %__nv_exp2f.exit1315 ], [ %4996, %4945 ] + %5265 = phi float [ %7317, %__nv_exp2f.exit1315 ], [ %4997, %4945 ] + %5266 = phi float [ %7318, %__nv_exp2f.exit1315 ], [ %4998, %4945 ] + %5267 = phi float [ %7319, %__nv_exp2f.exit1315 ], [ %4999, %4945 ] + %5268 = phi float [ %7320, %__nv_exp2f.exit1315 ], [ %5000, %4945 ] + %5269 = phi float [ %7321, %__nv_exp2f.exit1315 ], [ %5001, %4945 ] + %5270 = phi float [ %7322, %__nv_exp2f.exit1315 ], [ %5002, %4945 ] + %5271 = phi float [ %7323, %__nv_exp2f.exit1315 ], [ %5003, %4945 ] + %5272 = phi float [ %7324, %__nv_exp2f.exit1315 ], [ %5004, %4945 ] + %5273 = phi float [ %7325, %__nv_exp2f.exit1315 ], [ %5005, %4945 ] + %5274 = phi float [ %7326, %__nv_exp2f.exit1315 ], [ %5006, %4945 ] + %5275 = phi float [ %7327, %__nv_exp2f.exit1315 ], [ %5007, %4945 ] + %5276 = phi float [ %7328, 
%__nv_exp2f.exit1315 ], [ %5008, %4945 ] + %5277 = phi float [ %7329, %__nv_exp2f.exit1315 ], [ %5009, %4945 ] + %5278 = phi i32 [ %7333, %__nv_exp2f.exit1315 ], [ 0, %4945 ] + %5279 = phi <16 x i32> [ %7332, %__nv_exp2f.exit1315 ], [ %4942, %4945 ] + %5280 = icmp slt i32 %5278, %4884, !dbg !213 + %5281 = icmp slt i32 %5278, %4885, !dbg !213 + %5282 = add i32 %5146, 1, !dbg !213 + %5283 = icmp sgt i32 %5282, 1, !dbg !213 + %5284 = select i1 %5283, i32 0, i32 %5282, !dbg !213 + %5285 = add i32 %5148, 1, !dbg !213 + %5286 = icmp sgt i32 %5285, 2, !dbg !213 + %5287 = select i1 %5286, i32 0, i32 %5285, !dbg !213 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !244 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !244 + %5288 = shl i32 %5287, 13, !dbg !244 + %5289 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %5288, !dbg !244 + %5290 = shl i32 %5284, 6, !dbg !246 + %5291 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5290, !dbg !246 + %5292 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4812, !dbg !246 + %5293 = load float, ptr addrspace(3) %5292, align 8, !dbg !246 + %5294 = getelementptr inbounds nuw i8, ptr addrspace(3) %5292, i32 4, !dbg !246 + %5295 = load float, ptr addrspace(3) %5294, align 4, !dbg !246 + %5296 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4815, !dbg !246 + %5297 = load float, ptr addrspace(3) %5296, align 8, !dbg !246 + %5298 = getelementptr inbounds nuw i8, ptr addrspace(3) %5296, i32 4, !dbg !246 + %5299 = load float, ptr addrspace(3) %5298, align 4, !dbg !246 + %5300 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4817, !dbg !246 + %5301 = load float, ptr addrspace(3) %5300, align 8, !dbg !246 + %5302 = getelementptr inbounds nuw i8, ptr addrspace(3) %5300, i32 4, !dbg !246 + %5303 = load float, ptr addrspace(3) %5302, align 4, !dbg !246 + %5304 = getelementptr inbounds nuw i8, ptr 
addrspace(3) %5291, i32 %4819, !dbg !246 + %5305 = load float, ptr addrspace(3) %5304, align 8, !dbg !246 + %5306 = getelementptr inbounds nuw i8, ptr addrspace(3) %5304, i32 4, !dbg !246 + %5307 = load float, ptr addrspace(3) %5306, align 4, !dbg !246 + %5308 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4821, !dbg !246 + %5309 = load float, ptr addrspace(3) %5308, align 8, !dbg !246 + %5310 = getelementptr inbounds nuw i8, ptr addrspace(3) %5308, i32 4, !dbg !246 + %5311 = load float, ptr addrspace(3) %5310, align 4, !dbg !246 + %5312 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4823, !dbg !246 + %5313 = load float, ptr addrspace(3) %5312, align 8, !dbg !246 + %5314 = getelementptr inbounds nuw i8, ptr addrspace(3) %5312, i32 4, !dbg !246 + %5315 = load float, ptr addrspace(3) %5314, align 4, !dbg !246 + %5316 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4825, !dbg !246 + %5317 = load float, ptr addrspace(3) %5316, align 8, !dbg !246 + %5318 = getelementptr inbounds nuw i8, ptr addrspace(3) %5316, i32 4, !dbg !246 + %5319 = load float, ptr addrspace(3) %5318, align 4, !dbg !246 + %5320 = getelementptr inbounds nuw i8, ptr addrspace(3) %5291, i32 %4827, !dbg !246 + %5321 = load float, ptr addrspace(3) %5320, align 8, !dbg !246 + %5322 = getelementptr inbounds nuw i8, ptr addrspace(3) %5320, i32 4, !dbg !246 + %5323 = load float, ptr addrspace(3) %5322, align 4, !dbg !246 + %5324 = fcmp oeq float %5293, 0xFFF0000000000000, !dbg !251 + %5325 = fcmp oeq float %5295, 0xFFF0000000000000, !dbg !251 + %5326 = fcmp oeq float %5297, 0xFFF0000000000000, !dbg !251 + %5327 = fcmp oeq float %5299, 0xFFF0000000000000, !dbg !251 + %5328 = fcmp oeq float %5301, 0xFFF0000000000000, !dbg !251 + %5329 = fcmp oeq float %5303, 0xFFF0000000000000, !dbg !251 + %5330 = fcmp oeq float %5305, 0xFFF0000000000000, !dbg !251 + %5331 = fcmp oeq float %5307, 0xFFF0000000000000, !dbg !251 + %5332 = fcmp oeq float %5309, 0xFFF0000000000000, !dbg 
!251 + %5333 = fcmp oeq float %5311, 0xFFF0000000000000, !dbg !251 + %5334 = fcmp oeq float %5313, 0xFFF0000000000000, !dbg !251 + %5335 = fcmp oeq float %5315, 0xFFF0000000000000, !dbg !251 + %5336 = fcmp oeq float %5317, 0xFFF0000000000000, !dbg !251 + %5337 = fcmp oeq float %5319, 0xFFF0000000000000, !dbg !251 + %5338 = fcmp oeq float %5321, 0xFFF0000000000000, !dbg !251 + %5339 = fcmp oeq float %5323, 0xFFF0000000000000, !dbg !251 + %5340 = select i1 %5324, float 0.000000e+00, float %5293, !dbg !252 + %5341 = select i1 %5325, float 0.000000e+00, float %5295, !dbg !252 + %5342 = select i1 %5326, float 0.000000e+00, float %5297, !dbg !252 + %5343 = select i1 %5327, float 0.000000e+00, float %5299, !dbg !252 + %5344 = select i1 %5328, float 0.000000e+00, float %5301, !dbg !252 + %5345 = select i1 %5329, float 0.000000e+00, float %5303, !dbg !252 + %5346 = select i1 %5330, float 0.000000e+00, float %5305, !dbg !252 + %5347 = select i1 %5331, float 0.000000e+00, float %5307, !dbg !252 + %5348 = select i1 %5332, float 0.000000e+00, float %5309, !dbg !252 + %5349 = select i1 %5333, float 0.000000e+00, float %5311, !dbg !252 + %5350 = select i1 %5334, float 0.000000e+00, float %5313, !dbg !252 + %5351 = select i1 %5335, float 0.000000e+00, float %5315, !dbg !252 + %5352 = select i1 %5336, float 0.000000e+00, float %5317, !dbg !252 + %5353 = select i1 %5337, float 0.000000e+00, float %5319, !dbg !252 + %5354 = select i1 %5338, float 0.000000e+00, float %5321, !dbg !252 + %5355 = select i1 %5339, float 0.000000e+00, float %5323, !dbg !252 + %5356 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %36, i32 0, i32 31), !dbg !226 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !226 + %5357 = shl i32 %5356, 11, !dbg !226 + %5358 = and i32 %5357, 8192, !dbg !226 + %5359 = add i32 %5358, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5360 = lshr exact i32 %5359, 4, !dbg !226 + %5361 = and i32 
%5360, 16383, !dbg !226 + %5362 = zext nneg i32 %5361 to i64, !dbg !226 + %5363 = or disjoint i64 %5362, 4611686293372403712, !dbg !226 + %5364 = ptrtoint ptr addrspace(3) %5289 to i32, !dbg !226 + %5365 = lshr exact i32 %5364, 4, !dbg !226 + %5366 = and i32 %5365, 16383, !dbg !226 + %5367 = zext nneg i32 %5366 to i64, !dbg !226 + %5368 = or disjoint i64 %5367, 4611686293338849280, !dbg !226 + %5369 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %5363, i64 %5368) #3, !dbg !226 + %5370 = or disjoint i32 %5358, 32, !dbg !226 + %5371 = add i32 %5370, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5372 = lshr exact i32 %5371, 4, !dbg !226 + %5373 = and i32 %5372, 16383, !dbg !226 + %5374 = zext nneg i32 %5373 to i64, !dbg !226 + %5375 = or disjoint i64 %5374, 4611686293372403712, !dbg !226 + %5376 = add i32 %5364, 32, !dbg !226 + %5377 = lshr exact i32 %5376, 4, !dbg !226 + %5378 = and i32 %5377, 16383, !dbg !226 + %5379 = zext nneg i32 %5378 to i64, !dbg !226 + %5380 = or disjoint i64 %5379, 4611686293338849280, !dbg !226 + %5381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 0, !dbg !226 + %5382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 1, !dbg !226 + %5383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 2, !dbg !226 + %5384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 3, !dbg !226 + %5385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 4, !dbg !226 + %5386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 5, !dbg !226 + %5387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 6, !dbg !226 + %5388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 7, !dbg !226 + %5389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %5369, 8, !dbg !226 + %5390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 9, !dbg !226 + %5391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 10, !dbg !226 + %5392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 11, !dbg !226 + %5393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 12, !dbg !226 + %5394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 13, !dbg !226 + %5395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 14, !dbg !226 + %5396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 15, !dbg !226 + %5397 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 16, !dbg !226 + %5398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 17, !dbg !226 + %5399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 18, !dbg !226 + %5400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 19, !dbg !226 + %5401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 20, !dbg !226 + %5402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 21, !dbg !226 + %5403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 22, !dbg !226 + %5404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 23, !dbg !226 + %5405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 24, !dbg !226 + %5406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 25, !dbg !226 + %5407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 26, !dbg !226 + %5408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 27, !dbg !226 + %5409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 28, !dbg !226 + %5410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 29, !dbg !226 + %5411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%5369, 30, !dbg !226 + %5412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5369, 31, !dbg !226 + %5413 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5381, float %5382, float %5383, float %5384, float %5385, float %5386, float %5387, float %5388, float %5389, float %5390, float %5391, float %5392, float %5393, float %5394, float %5395, float %5396, float %5397, float %5398, float %5399, float %5400, float %5401, float %5402, float %5403, float %5404, float %5405, float %5406, float %5407, float %5408, float %5409, float %5410, float %5411, float %5412, i64 %5375, i64 %5380, i1 true) #3, !dbg !226 + %5414 = or disjoint i32 %5358, 64, !dbg !226 + %5415 = add i32 %5414, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5416 = lshr exact i32 %5415, 4, !dbg !226 + %5417 = and i32 %5416, 16383, !dbg !226 + %5418 = zext nneg i32 %5417 to i64, !dbg !226 + %5419 = or disjoint i64 %5418, 4611686293372403712, !dbg !226 + %5420 = add i32 %5364, 64, !dbg !226 + %5421 = lshr exact i32 %5420, 4, !dbg !226 + %5422 = and i32 %5421, 16383, !dbg !226 + %5423 = zext nneg i32 %5422 to i64, !dbg !226 + %5424 = or disjoint i64 %5423, 
4611686293338849280, !dbg !226 + %5425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 0, !dbg !226 + %5426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 1, !dbg !226 + %5427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 2, !dbg !226 + %5428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 3, !dbg !226 + %5429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 4, !dbg !226 + %5430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 5, !dbg !226 + %5431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 6, !dbg !226 + %5432 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 7, !dbg !226 + %5433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 8, !dbg !226 + %5434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 9, !dbg !226 + %5435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 10, !dbg !226 + %5436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 11, !dbg !226 + %5437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 12, !dbg !226 + %5438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 13, !dbg !226 + %5439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %5413, 14, !dbg !226 + %5440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 15, !dbg !226 + %5441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 16, !dbg !226 + %5442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 17, !dbg !226 + %5443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 18, !dbg !226 + %5444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 19, !dbg !226 + %5445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 20, !dbg !226 + %5446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 21, !dbg !226 + %5447 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 22, !dbg !226 + %5448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 23, !dbg !226 + %5449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 24, !dbg !226 + %5450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 25, !dbg !226 + %5451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 26, !dbg !226 + %5452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 27, !dbg !226 + %5453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 28, !dbg !226 + %5454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 29, !dbg !226 + %5455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 30, !dbg !226 + %5456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5413, 31, !dbg !226 + %5457 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5425, float %5426, float %5427, float %5428, float %5429, float %5430, float %5431, float %5432, float %5433, float %5434, float %5435, float %5436, float %5437, float %5438, float %5439, float %5440, float %5441, float %5442, float %5443, float %5444, float %5445, float %5446, float %5447, float %5448, float %5449, float %5450, float %5451, float %5452, float %5453, float %5454, float %5455, float %5456, i64 %5419, i64 %5424, i1 true) #3, !dbg !226 + %5458 = or disjoint i32 %5358, 96, !dbg !226 + %5459 = add i32 %5458, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 
+ %5460 = lshr exact i32 %5459, 4, !dbg !226 + %5461 = and i32 %5460, 16383, !dbg !226 + %5462 = zext nneg i32 %5461 to i64, !dbg !226 + %5463 = or disjoint i64 %5462, 4611686293372403712, !dbg !226 + %5464 = add i32 %5364, 96, !dbg !226 + %5465 = lshr exact i32 %5464, 4, !dbg !226 + %5466 = and i32 %5465, 16383, !dbg !226 + %5467 = zext nneg i32 %5466 to i64, !dbg !226 + %5468 = or disjoint i64 %5467, 4611686293338849280, !dbg !226 + %5469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 0, !dbg !226 + %5470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 1, !dbg !226 + %5471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 2, !dbg !226 + %5472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 3, !dbg !226 + %5473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 4, !dbg !226 + %5474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %5457, 5, !dbg !226 + %5475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 6, !dbg !226 + %5476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 7, !dbg !226 + %5477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 8, !dbg !226 + %5478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 9, !dbg !226 + %5479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 10, !dbg !226 + %5480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 11, !dbg !226 + %5481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 12, !dbg !226 + %5482 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 13, !dbg !226 + %5483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 14, !dbg !226 + %5484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 15, !dbg !226 + %5485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 16, !dbg !226 + %5486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 17, !dbg !226 + %5487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 18, !dbg !226 + %5488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 19, !dbg !226 + %5489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 20, !dbg !226 + %5490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 21, !dbg !226 + %5491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 22, !dbg !226 + %5492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 23, !dbg !226 + %5493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 24, !dbg !226 + %5494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 25, !dbg !226 + %5495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 26, !dbg !226 + %5496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %5457, 27, !dbg !226 + %5497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 28, !dbg !226 + %5498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 29, !dbg !226 + %5499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 30, !dbg !226 + %5500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5457, 31, !dbg !226 + %5501 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5469, float %5470, float %5471, float %5472, float %5473, float %5474, float %5475, float %5476, float %5477, float %5478, float %5479, float %5480, float %5481, float %5482, float %5483, float %5484, float %5485, 
float %5486, float %5487, float %5488, float %5489, float %5490, float %5491, float %5492, float %5493, float %5494, float %5495, float %5496, float %5497, float %5498, float %5499, float %5500, i64 %5463, i64 %5468, i1 true) #3, !dbg !226 + %5502 = or disjoint i32 %5358, 16384, !dbg !226 + %5503 = add i32 %5502, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5504 = lshr exact i32 %5503, 4, !dbg !226 + %5505 = and i32 %5504, 16383, !dbg !226 + %5506 = zext nneg i32 %5505 to i64, !dbg !226 + %5507 = or disjoint i64 %5506, 4611686293372403712, !dbg !226 + %5508 = add i32 %5364, 8192, !dbg !226 + %5509 = lshr exact i32 %5508, 4, !dbg !226 + %5510 = and i32 %5509, 16383, !dbg !226 + %5511 = zext nneg i32 %5510 to i64, !dbg !226 + %5512 = or disjoint i64 %5511, 4611686293338849280, !dbg !226 + %5513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 0, !dbg !226 + %5514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 1, !dbg !226 + %5515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 2, !dbg !226 + %5516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 3, !dbg !226 + %5517 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 4, !dbg !226 + %5518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 5, !dbg !226 + %5519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 6, !dbg !226 + %5520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 7, !dbg !226 + %5521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 8, !dbg !226 + %5522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 9, !dbg !226 + %5523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 10, !dbg !226 + %5524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 11, !dbg !226 + %5525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 12, !dbg !226 + %5526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 13, !dbg !226 + %5527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 14, !dbg !226 + %5528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 15, !dbg !226 + %5529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 16, !dbg !226 + %5530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 17, !dbg !226 + %5531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 18, 
!dbg !226 + %5532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 19, !dbg !226 + %5533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 20, !dbg !226 + %5534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 21, !dbg !226 + %5535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 22, !dbg !226 + %5536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 23, !dbg !226 + %5537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 24, !dbg !226 + %5538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 25, !dbg !226 + %5539 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 26, !dbg !226 + %5540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 27, !dbg !226 + %5541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 28, !dbg !226 + %5542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 29, !dbg !226 + %5543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 30, !dbg !226 + %5544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5501, 31, !dbg !226 + %5545 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 
0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5513, float %5514, float %5515, float %5516, float %5517, float %5518, float %5519, float %5520, float %5521, float %5522, float %5523, float %5524, float %5525, float %5526, float %5527, float %5528, float %5529, float %5530, float %5531, float %5532, float %5533, float %5534, float %5535, float %5536, float %5537, float %5538, float %5539, float %5540, float %5541, float %5542, float %5543, float %5544, i64 %5507, i64 %5512, i1 true) #3, !dbg !226 + %5546 = or disjoint i32 %5358, 16416, !dbg !226 + %5547 = add i32 %5546, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5548 = lshr exact i32 %5547, 4, !dbg !226 + %5549 = and i32 %5548, 16383, !dbg !226 + %5550 = zext nneg i32 %5549 to i64, !dbg !226 + %5551 = or disjoint i64 %5550, 4611686293372403712, !dbg !226 + %5552 = add i32 %5364, 8224, !dbg !226 + %5553 = lshr exact i32 %5552, 4, !dbg !226 + %5554 = and i32 %5553, 16383, !dbg !226 + %5555 = zext nneg i32 %5554 to i64, !dbg !226 + %5556 = or disjoint i64 %5555, 4611686293338849280, !dbg !226 + %5557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 0, !dbg !226 + %5558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 1, !dbg !226 + %5559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %5545, 2, !dbg !226 + %5560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 3, !dbg !226 + %5561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 4, !dbg !226 + %5562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 5, !dbg !226 + %5563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 6, !dbg !226 + %5564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 7, !dbg !226 + %5565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 8, !dbg !226 + %5566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 9, !dbg !226 + %5567 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 10, !dbg !226 + %5568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 11, !dbg !226 + %5569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 12, !dbg !226 + %5570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 13, !dbg !226 + %5571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 14, !dbg !226 + %5572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 15, !dbg !226 + %5573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 16, !dbg !226 + %5574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 17, !dbg !226 + %5575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 18, !dbg !226 + %5576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 19, !dbg !226 + %5577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 20, !dbg !226 + %5578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 21, !dbg !226 + %5579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 22, !dbg !226 + %5580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 23, !dbg !226 + %5581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %5545, 24, !dbg !226 + %5582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 25, !dbg !226 + %5583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 26, !dbg !226 + %5584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 27, !dbg !226 + %5585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 28, !dbg !226 + %5586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 29, !dbg !226 + %5587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 30, !dbg !226 + %5588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5545, 31, !dbg !226 + %5589 = tail call { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5557, float %5558, float %5559, float %5560, float %5561, float %5562, float %5563, float %5564, float %5565, float %5566, float %5567, float %5568, float %5569, float %5570, float %5571, float %5572, float %5573, float %5574, float %5575, float %5576, float %5577, float %5578, float %5579, float %5580, float %5581, float %5582, float %5583, float %5584, float %5585, float %5586, float %5587, float %5588, i64 %5551, i64 %5556, i1 true) #3, !dbg !226 + %5590 = or disjoint i32 %5358, 16448, !dbg !226 + %5591 = add i32 %5590, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5592 = lshr exact i32 %5591, 4, !dbg !226 + %5593 = and i32 %5592, 16383, !dbg !226 + %5594 = zext nneg i32 %5593 to i64, !dbg !226 + %5595 = or disjoint i64 %5594, 4611686293372403712, !dbg !226 + %5596 = add i32 %5364, 8256, !dbg !226 + %5597 = lshr exact i32 %5596, 4, !dbg !226 + %5598 = and i32 %5597, 16383, !dbg !226 + %5599 = zext nneg i32 %5598 to i64, !dbg !226 + %5600 = or disjoint i64 %5599, 4611686293338849280, !dbg !226 + %5601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 0, !dbg !226 + %5602 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 1, !dbg !226 + %5603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 2, !dbg !226 + %5604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 3, !dbg !226 + %5605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 4, !dbg !226 + %5606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 5, !dbg !226 + %5607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 6, !dbg !226 + %5608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 7, !dbg !226 + %5609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 8, !dbg !226 + %5610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 9, !dbg !226 + %5611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 10, !dbg !226 + %5612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 11, !dbg !226 + %5613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 12, !dbg !226 + %5614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 13, !dbg !226 + %5615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 14, !dbg !226 + %5616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %5589, 15, !dbg !226 + %5617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 16, !dbg !226 + %5618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 17, !dbg !226 + %5619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 18, !dbg !226 + %5620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 19, !dbg !226 + %5621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 20, !dbg !226 + %5622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 21, !dbg !226 + %5623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 22, !dbg !226 + %5624 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 23, !dbg !226 + %5625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 24, !dbg !226 + %5626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 25, !dbg !226 + %5627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 26, !dbg !226 + %5628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 27, !dbg !226 + %5629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 28, !dbg !226 + %5630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 29, !dbg !226 + %5631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %5589, 30, !dbg !226 + %5632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5589, 31, !dbg !226 + %5633 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5601, float %5602, float %5603, float %5604, float %5605, float %5606, float %5607, float %5608, float %5609, float %5610, float %5611, float %5612, float %5613, float %5614, float %5615, float %5616, float %5617, float %5618, float %5619, float %5620, float %5621, float %5622, float %5623, float %5624, float %5625, float %5626, float %5627, float %5628, float %5629, float %5630, float %5631, float %5632, i64 %5595, i64 %5600, i1 true) #3, !dbg !226 + %5634 = or disjoint i32 %5358, 16480, !dbg !226 + %5635 = add i32 %5634, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !226 + %5636 = lshr exact i32 %5635, 4, !dbg !226 + %5637 = and i32 %5636, 16383, !dbg !226 + %5638 = zext nneg i32 %5637 to i64, !dbg !226 + %5639 = or disjoint i64 %5638, 4611686293372403712, !dbg !226 + %5640 = add i32 %5364, 8288, !dbg !226 + %5641 = lshr exact i32 %5640, 4, !dbg !226 + %5642 = and i32 %5641, 16383, !dbg 
!226 + %5643 = zext nneg i32 %5642 to i64, !dbg !226 + %5644 = or disjoint i64 %5643, 4611686293338849280, !dbg !226 + %5645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 0, !dbg !226 + %5646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 1, !dbg !226 + %5647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 2, !dbg !226 + %5648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 3, !dbg !226 + %5649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 4, !dbg !226 + %5650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 5, !dbg !226 + %5651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 6, 
!dbg !226 + %5652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 7, !dbg !226 + %5653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 8, !dbg !226 + %5654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 9, !dbg !226 + %5655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 10, !dbg !226 + %5656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 11, !dbg !226 + %5657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 12, !dbg !226 + %5658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 13, !dbg !226 + %5659 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 14, !dbg !226 + %5660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 15, !dbg !226 + %5661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 16, !dbg !226 + %5662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 17, !dbg !226 + %5663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 18, !dbg !226 + %5664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 19, !dbg !226 + %5665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 20, !dbg !226 + %5666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %5633, 21, !dbg !226 + %5667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 22, !dbg !226 + %5668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 23, !dbg !226 + %5669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 24, !dbg !226 + %5670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 25, !dbg !226 + %5671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 26, !dbg !226 + %5672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 27, !dbg !226 + %5673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 28, !dbg !226 + %5674 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 29, !dbg !226 + %5675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 30, !dbg !226 + %5676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5633, 31, !dbg !226 + %5677 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %5645, float %5646, float %5647, float %5648, float %5649, float %5650, float %5651, float %5652, float %5653, float %5654, float %5655, float %5656, float %5657, float %5658, float %5659, float %5660, float %5661, float %5662, float %5663, float %5664, float %5665, float %5666, float %5667, float %5668, float %5669, float %5670, float %5671, float %5672, float %5673, float %5674, float %5675, float %5676, i64 %5639, i64 %5644, i1 true) #3, !dbg !226 + %5678 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 0, !dbg !226 + %5679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 1, !dbg !226 + %5680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 2, !dbg !226 + %5681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 3, !dbg !226 + %5682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 4, !dbg !226 + %5683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 5, !dbg !226 + %5684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 6, !dbg !226 + %5685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %5677, 7, !dbg !226 + %5686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 8, !dbg !226 + %5687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 9, !dbg !226 + %5688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 10, !dbg !226 + %5689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 11, !dbg !226 + %5690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 12, !dbg !226 + %5691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 13, !dbg !226 + %5692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 14, !dbg !226 + %5693 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 15, !dbg !226 + %5694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 16, !dbg !226 + %5695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 17, !dbg !226 + %5696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 18, !dbg !226 + %5697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 19, !dbg !226 + %5698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 20, !dbg !226 + %5699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 21, !dbg !226 + %5700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 22, !dbg !226 + %5701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 23, !dbg !226 + %5702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 24, !dbg !226 + %5703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 25, !dbg !226 + %5704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 26, !dbg !226 + %5705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 27, !dbg !226 + %5706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 28, !dbg !226 + %5707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %5677, 29, !dbg !226 + %5708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 30, !dbg !226 + %5709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %5677, 31, !dbg !226 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !226 + %5710 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %5678, float %5679, float %5680, float %5681, float %5682, float %5683, float %5684, float %5685, float %5686, float %5687, float %5688, float %5689, float %5690, float %5691, float %5692, float %5693, float %5694, float %5695, float %5696, float %5697, float %5698, float %5699, float %5700, float %5701, float %5702, float %5703, float %5704, float %5705, float %5706, float %5707, float %5708, float %5709, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %5289, i32 0, i32 0) #3, !dbg !226 + %5711 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 0, !dbg !226 + %5712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 1, !dbg !226 + %5713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 2, !dbg !226 + %5714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 3, !dbg !226 + %5715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 4, !dbg !226 + %5716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 5, !dbg !226 + %5717 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 6, !dbg !226 + %5718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 7, !dbg !226 + %5719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 8, !dbg !226 + %5720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 9, !dbg !226 + %5721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 10, !dbg !226 + %5722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 11, !dbg !226 + %5723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 12, !dbg !226 + %5724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 13, !dbg !226 + %5725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 14, !dbg !226 + %5726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 15, !dbg !226 + %5727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 16, !dbg !226 + %5728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 17, !dbg !226 + %5729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 18, !dbg !226 + %5730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 19, !dbg !226 + %5731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 20, !dbg !226 + %5732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 21, !dbg !226 + %5733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 22, !dbg !226 + %5734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 23, !dbg !226 + %5735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 24, !dbg !226 + %5736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 25, !dbg !226 + %5737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 26, !dbg !226 + %5738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 27, !dbg !226 + %5739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 28, !dbg !226 + %5740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 29, !dbg !226 + %5741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 30, !dbg !226 + %5742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %5710, 31, !dbg !226 + %5743 = fmul float %5711, 0x3FB6A09E60000000, !dbg !253 + %5744 = fmul float %5712, 0x3FB6A09E60000000, !dbg !253 + %5745 = fmul float %5713, 0x3FB6A09E60000000, !dbg !253 + %5746 = fmul float %5714, 0x3FB6A09E60000000, !dbg !253 + %5747 = fmul float %5715, 0x3FB6A09E60000000, !dbg !253 + %5748 = fmul float %5716, 0x3FB6A09E60000000, !dbg !253 + %5749 = fmul float %5717, 0x3FB6A09E60000000, !dbg !253 + %5750 = fmul float %5718, 0x3FB6A09E60000000, !dbg !253 + %5751 = fmul float %5719, 0x3FB6A09E60000000, !dbg !253 + %5752 = fmul float %5720, 0x3FB6A09E60000000, !dbg !253 + %5753 = fmul float %5721, 0x3FB6A09E60000000, !dbg !253 + %5754 = fmul float %5722, 0x3FB6A09E60000000, !dbg !253 + %5755 = fmul float %5723, 0x3FB6A09E60000000, !dbg !253 + %5756 = fmul float %5724, 0x3FB6A09E60000000, !dbg !253 + %5757 = fmul float %5725, 0x3FB6A09E60000000, !dbg !253 + %5758 = fmul float %5726, 0x3FB6A09E60000000, !dbg !253 + %5759 = fmul float %5727, 0x3FB6A09E60000000, !dbg !253 + %5760 = fmul float %5728, 0x3FB6A09E60000000, !dbg !253 + %5761 = fmul float %5729, 0x3FB6A09E60000000, !dbg !253 + %5762 = fmul float %5730, 0x3FB6A09E60000000, !dbg !253 + %5763 = fmul float %5731, 0x3FB6A09E60000000, !dbg !253 + %5764 = fmul float %5732, 0x3FB6A09E60000000, !dbg !253 + %5765 = fmul float %5733, 0x3FB6A09E60000000, !dbg !253 + %5766 = fmul float %5734, 0x3FB6A09E60000000, !dbg !253 + %5767 = fmul float %5735, 0x3FB6A09E60000000, !dbg !253 + %5768 = fmul float %5736, 0x3FB6A09E60000000, !dbg !253 + %5769 = fmul float %5737, 0x3FB6A09E60000000, !dbg !253 + %5770 
= fmul float %5738, 0x3FB6A09E60000000, !dbg !253 + %5771 = fmul float %5739, 0x3FB6A09E60000000, !dbg !253 + %5772 = fmul float %5740, 0x3FB6A09E60000000, !dbg !253 + %5773 = fmul float %5741, 0x3FB6A09E60000000, !dbg !253 + %5774 = fmul float %5742, 0x3FB6A09E60000000, !dbg !253 + %5775 = extractelement <16 x i32> %5279, i64 15, !dbg !254 + %5776 = icmp sge i32 %5775, %4470, !dbg !254 + %5777 = extractelement <16 x i32> %5279, i64 14, !dbg !254 + %5778 = icmp sge i32 %5777, %4470, !dbg !254 + %5779 = icmp sge i32 %5775, %4471, !dbg !254 + %5780 = icmp sge i32 %5777, %4471, !dbg !254 + %5781 = extractelement <16 x i32> %5279, i64 13, !dbg !254 + %5782 = icmp sge i32 %5781, %4470, !dbg !254 + %5783 = extractelement <16 x i32> %5279, i64 12, !dbg !254 + %5784 = icmp sge i32 %5783, %4470, !dbg !254 + %5785 = icmp sge i32 %5781, %4471, !dbg !254 + %5786 = icmp sge i32 %5783, %4471, !dbg !254 + %5787 = extractelement <16 x i32> %5279, i64 11, !dbg !254 + %5788 = icmp sge i32 %5787, %4470, !dbg !254 + %5789 = extractelement <16 x i32> %5279, i64 10, !dbg !254 + %5790 = icmp sge i32 %5789, %4470, !dbg !254 + %5791 = icmp sge i32 %5787, %4471, !dbg !254 + %5792 = icmp sge i32 %5789, %4471, !dbg !254 + %5793 = extractelement <16 x i32> %5279, i64 9, !dbg !254 + %5794 = icmp sge i32 %5793, %4470, !dbg !254 + %5795 = extractelement <16 x i32> %5279, i64 8, !dbg !254 + %5796 = icmp sge i32 %5795, %4470, !dbg !254 + %5797 = icmp sge i32 %5793, %4471, !dbg !254 + %5798 = icmp sge i32 %5795, %4471, !dbg !254 + %5799 = extractelement <16 x i32> %5279, i64 7, !dbg !254 + %5800 = icmp sge i32 %5799, %4470, !dbg !254 + %5801 = extractelement <16 x i32> %5279, i64 6, !dbg !254 + %5802 = icmp sge i32 %5801, %4470, !dbg !254 + %5803 = icmp sge i32 %5799, %4471, !dbg !254 + %5804 = icmp sge i32 %5801, %4471, !dbg !254 + %5805 = extractelement <16 x i32> %5279, i64 5, !dbg !254 + %5806 = icmp sge i32 %5805, %4470, !dbg !254 + %5807 = extractelement <16 x i32> %5279, i64 4, !dbg !254 + 
%5808 = icmp sge i32 %5807, %4470, !dbg !254 + %5809 = icmp sge i32 %5805, %4471, !dbg !254 + %5810 = icmp sge i32 %5807, %4471, !dbg !254 + %5811 = extractelement <16 x i32> %5279, i64 3, !dbg !254 + %5812 = icmp sge i32 %5811, %4470, !dbg !254 + %5813 = extractelement <16 x i32> %5279, i64 2, !dbg !254 + %5814 = icmp sge i32 %5813, %4470, !dbg !254 + %5815 = icmp sge i32 %5811, %4471, !dbg !254 + %5816 = icmp sge i32 %5813, %4471, !dbg !254 + %5817 = extractelement <16 x i32> %5279, i64 1, !dbg !254 + %5818 = icmp sge i32 %5817, %4470, !dbg !254 + %5819 = extractelement <16 x i32> %5279, i64 0, !dbg !254 + %5820 = icmp sge i32 %5819, %4470, !dbg !254 + %5821 = icmp sge i32 %5817, %4471, !dbg !254 + %5822 = icmp sge i32 %5819, %4471, !dbg !254 + %5823 = sext <16 x i32> %5279 to <16 x i64>, !dbg !255 + %5824 = icmp sgt <16 x i64> %4944, %5823, !dbg !255 + %5825 = extractelement <16 x i1> %5824, i64 15, !dbg !256 + %5826 = and i1 %5776, %5825, !dbg !256 + %5827 = extractelement <16 x i1> %5824, i64 14, !dbg !256 + %5828 = and i1 %5778, %5827, !dbg !256 + %5829 = and i1 %5779, %5825, !dbg !256 + %5830 = and i1 %5780, %5827, !dbg !256 + %5831 = extractelement <16 x i1> %5824, i64 13, !dbg !256 + %5832 = and i1 %5782, %5831, !dbg !256 + %5833 = extractelement <16 x i1> %5824, i64 12, !dbg !256 + %5834 = and i1 %5784, %5833, !dbg !256 + %5835 = and i1 %5785, %5831, !dbg !256 + %5836 = and i1 %5786, %5833, !dbg !256 + %5837 = extractelement <16 x i1> %5824, i64 11, !dbg !256 + %5838 = and i1 %5788, %5837, !dbg !256 + %5839 = extractelement <16 x i1> %5824, i64 10, !dbg !256 + %5840 = and i1 %5790, %5839, !dbg !256 + %5841 = and i1 %5791, %5837, !dbg !256 + %5842 = and i1 %5792, %5839, !dbg !256 + %5843 = extractelement <16 x i1> %5824, i64 9, !dbg !256 + %5844 = and i1 %5794, %5843, !dbg !256 + %5845 = extractelement <16 x i1> %5824, i64 8, !dbg !256 + %5846 = and i1 %5796, %5845, !dbg !256 + %5847 = and i1 %5797, %5843, !dbg !256 + %5848 = and i1 %5798, %5845, !dbg !256 
+ %5849 = extractelement <16 x i1> %5824, i64 7, !dbg !256 + %5850 = and i1 %5800, %5849, !dbg !256 + %5851 = extractelement <16 x i1> %5824, i64 6, !dbg !256 + %5852 = and i1 %5802, %5851, !dbg !256 + %5853 = and i1 %5803, %5849, !dbg !256 + %5854 = and i1 %5804, %5851, !dbg !256 + %5855 = extractelement <16 x i1> %5824, i64 5, !dbg !256 + %5856 = and i1 %5806, %5855, !dbg !256 + %5857 = extractelement <16 x i1> %5824, i64 4, !dbg !256 + %5858 = and i1 %5808, %5857, !dbg !256 + %5859 = and i1 %5809, %5855, !dbg !256 + %5860 = and i1 %5810, %5857, !dbg !256 + %5861 = extractelement <16 x i1> %5824, i64 3, !dbg !256 + %5862 = and i1 %5812, %5861, !dbg !256 + %5863 = extractelement <16 x i1> %5824, i64 2, !dbg !256 + %5864 = and i1 %5814, %5863, !dbg !256 + %5865 = and i1 %5815, %5861, !dbg !256 + %5866 = and i1 %5816, %5863, !dbg !256 + %5867 = extractelement <16 x i1> %5824, i64 1, !dbg !256 + %5868 = and i1 %5818, %5867, !dbg !256 + %5869 = extractelement <16 x i1> %5824, i64 0, !dbg !256 + %5870 = and i1 %5820, %5869, !dbg !256 + %5871 = and i1 %5821, %5867, !dbg !256 + %5872 = and i1 %5822, %5869, !dbg !256 + %5873 = fmul float %5743, 0x3FF7154760000000, !dbg !257 + %5874 = select i1 %5826, float %5873, float 0xFFF0000000000000, !dbg !258 + %5875 = fmul float %5744, 0x3FF7154760000000, !dbg !257 + %5876 = select i1 %5828, float %5875, float 0xFFF0000000000000, !dbg !258 + %5877 = fmul float %5745, 0x3FF7154760000000, !dbg !257 + %5878 = select i1 %5829, float %5877, float 0xFFF0000000000000, !dbg !258 + %5879 = fmul float %5746, 0x3FF7154760000000, !dbg !257 + %5880 = select i1 %5830, float %5879, float 0xFFF0000000000000, !dbg !258 + %5881 = fmul float %5747, 0x3FF7154760000000, !dbg !257 + %5882 = select i1 %5832, float %5881, float 0xFFF0000000000000, !dbg !258 + %5883 = fmul float %5748, 0x3FF7154760000000, !dbg !257 + %5884 = select i1 %5834, float %5883, float 0xFFF0000000000000, !dbg !258 + %5885 = fmul float %5749, 0x3FF7154760000000, !dbg !257 + %5886 = 
select i1 %5835, float %5885, float 0xFFF0000000000000, !dbg !258 + %5887 = fmul float %5750, 0x3FF7154760000000, !dbg !257 + %5888 = select i1 %5836, float %5887, float 0xFFF0000000000000, !dbg !258 + %5889 = fmul float %5751, 0x3FF7154760000000, !dbg !257 + %5890 = select i1 %5838, float %5889, float 0xFFF0000000000000, !dbg !258 + %5891 = fmul float %5752, 0x3FF7154760000000, !dbg !257 + %5892 = select i1 %5840, float %5891, float 0xFFF0000000000000, !dbg !258 + %5893 = fmul float %5753, 0x3FF7154760000000, !dbg !257 + %5894 = select i1 %5841, float %5893, float 0xFFF0000000000000, !dbg !258 + %5895 = fmul float %5754, 0x3FF7154760000000, !dbg !257 + %5896 = select i1 %5842, float %5895, float 0xFFF0000000000000, !dbg !258 + %5897 = fmul float %5755, 0x3FF7154760000000, !dbg !257 + %5898 = select i1 %5844, float %5897, float 0xFFF0000000000000, !dbg !258 + %5899 = fmul float %5756, 0x3FF7154760000000, !dbg !257 + %5900 = select i1 %5846, float %5899, float 0xFFF0000000000000, !dbg !258 + %5901 = fmul float %5757, 0x3FF7154760000000, !dbg !257 + %5902 = select i1 %5847, float %5901, float 0xFFF0000000000000, !dbg !258 + %5903 = fmul float %5758, 0x3FF7154760000000, !dbg !257 + %5904 = select i1 %5848, float %5903, float 0xFFF0000000000000, !dbg !258 + %5905 = fmul float %5759, 0x3FF7154760000000, !dbg !257 + %5906 = select i1 %5850, float %5905, float 0xFFF0000000000000, !dbg !258 + %5907 = fmul float %5760, 0x3FF7154760000000, !dbg !257 + %5908 = select i1 %5852, float %5907, float 0xFFF0000000000000, !dbg !258 + %5909 = fmul float %5761, 0x3FF7154760000000, !dbg !257 + %5910 = select i1 %5853, float %5909, float 0xFFF0000000000000, !dbg !258 + %5911 = fmul float %5762, 0x3FF7154760000000, !dbg !257 + %5912 = select i1 %5854, float %5911, float 0xFFF0000000000000, !dbg !258 + %5913 = fmul float %5763, 0x3FF7154760000000, !dbg !257 + %5914 = select i1 %5856, float %5913, float 0xFFF0000000000000, !dbg !258 + %5915 = fmul float %5764, 0x3FF7154760000000, !dbg !257 
+ %5916 = select i1 %5858, float %5915, float 0xFFF0000000000000, !dbg !258 + %5917 = fmul float %5765, 0x3FF7154760000000, !dbg !257 + %5918 = select i1 %5859, float %5917, float 0xFFF0000000000000, !dbg !258 + %5919 = fmul float %5766, 0x3FF7154760000000, !dbg !257 + %5920 = select i1 %5860, float %5919, float 0xFFF0000000000000, !dbg !258 + %5921 = fmul float %5767, 0x3FF7154760000000, !dbg !257 + %5922 = select i1 %5862, float %5921, float 0xFFF0000000000000, !dbg !258 + %5923 = fmul float %5768, 0x3FF7154760000000, !dbg !257 + %5924 = select i1 %5864, float %5923, float 0xFFF0000000000000, !dbg !258 + %5925 = fmul float %5769, 0x3FF7154760000000, !dbg !257 + %5926 = select i1 %5865, float %5925, float 0xFFF0000000000000, !dbg !258 + %5927 = fmul float %5770, 0x3FF7154760000000, !dbg !257 + %5928 = select i1 %5866, float %5927, float 0xFFF0000000000000, !dbg !258 + %5929 = fmul float %5771, 0x3FF7154760000000, !dbg !257 + %5930 = select i1 %5868, float %5929, float 0xFFF0000000000000, !dbg !258 + %5931 = fmul float %5772, 0x3FF7154760000000, !dbg !257 + %5932 = select i1 %5870, float %5931, float 0xFFF0000000000000, !dbg !258 + %5933 = fmul float %5773, 0x3FF7154760000000, !dbg !257 + %5934 = select i1 %5871, float %5933, float 0xFFF0000000000000, !dbg !258 + %5935 = fmul float %5774, 0x3FF7154760000000, !dbg !257 + %5936 = select i1 %5872, float %5935, float 0xFFF0000000000000, !dbg !258 + %5937 = fsub float %5874, %5340, !dbg !259 + %5938 = fsub float %5876, %5341, !dbg !259 + %5939 = fsub float %5878, %5340, !dbg !259 + %5940 = fsub float %5880, %5341, !dbg !259 + %5941 = fsub float %5882, %5342, !dbg !259 + %5942 = fsub float %5884, %5343, !dbg !259 + %5943 = fsub float %5886, %5342, !dbg !259 + %5944 = fsub float %5888, %5343, !dbg !259 + %5945 = fsub float %5890, %5344, !dbg !259 + %5946 = fsub float %5892, %5345, !dbg !259 + %5947 = fsub float %5894, %5344, !dbg !259 + %5948 = fsub float %5896, %5345, !dbg !259 + %5949 = fsub float %5898, %5346, !dbg 
!259 + %5950 = fsub float %5900, %5347, !dbg !259 + %5951 = fsub float %5902, %5346, !dbg !259 + %5952 = fsub float %5904, %5347, !dbg !259 + %5953 = fsub float %5906, %5348, !dbg !259 + %5954 = fsub float %5908, %5349, !dbg !259 + %5955 = fsub float %5910, %5348, !dbg !259 + %5956 = fsub float %5912, %5349, !dbg !259 + %5957 = fsub float %5914, %5350, !dbg !259 + %5958 = fsub float %5916, %5351, !dbg !259 + %5959 = fsub float %5918, %5350, !dbg !259 + %5960 = fsub float %5920, %5351, !dbg !259 + %5961 = fsub float %5922, %5352, !dbg !259 + %5962 = fsub float %5924, %5353, !dbg !259 + %5963 = fsub float %5926, %5352, !dbg !259 + %5964 = fsub float %5928, %5353, !dbg !259 + %5965 = fsub float %5930, %5354, !dbg !259 + %5966 = fsub float %5932, %5355, !dbg !259 + %5967 = fsub float %5934, %5354, !dbg !259 + %5968 = fsub float %5936, %5355, !dbg !259 + %5969 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1220 = icmp eq i32 %5969, 0, !dbg !260 + br i1 %.not.i1220, label %5972, label %5970, !dbg !260 + +5970: ; preds = %.lr.ph + %5971 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5937) #3, !dbg !260 + br label %__nv_exp2f.exit1222, !dbg !260 + +5972: ; preds = %.lr.ph + %5973 = tail call float @llvm.nvvm.ex2.approx.f(float %5937) #3, !dbg !260 + br label %__nv_exp2f.exit1222, !dbg !260 + +__nv_exp2f.exit1222: ; preds = %5970, %5972 + %.0.i1221 = phi float [ %5971, %5970 ], [ %5973, %5972 ], !dbg !260 + %5974 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1223 = icmp eq i32 %5974, 0, !dbg !260 + br i1 %.not.i1223, label %5977, label %5975, !dbg !260 + +5975: ; preds = %__nv_exp2f.exit1222 + %5976 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5938) #3, !dbg !260 + br label %__nv_exp2f.exit1225, !dbg !260 + +5977: ; preds = %__nv_exp2f.exit1222 + %5978 = tail call float @llvm.nvvm.ex2.approx.f(float %5938) #3, !dbg !260 + br label %__nv_exp2f.exit1225, !dbg !260 + +__nv_exp2f.exit1225: ; preds = %5975, 
%5977 + %.0.i1224 = phi float [ %5976, %5975 ], [ %5978, %5977 ], !dbg !260 + %5979 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1226 = icmp eq i32 %5979, 0, !dbg !260 + br i1 %.not.i1226, label %5982, label %5980, !dbg !260 + +5980: ; preds = %__nv_exp2f.exit1225 + %5981 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5939) #3, !dbg !260 + br label %__nv_exp2f.exit1228, !dbg !260 + +5982: ; preds = %__nv_exp2f.exit1225 + %5983 = tail call float @llvm.nvvm.ex2.approx.f(float %5939) #3, !dbg !260 + br label %__nv_exp2f.exit1228, !dbg !260 + +__nv_exp2f.exit1228: ; preds = %5980, %5982 + %.0.i1227 = phi float [ %5981, %5980 ], [ %5983, %5982 ], !dbg !260 + %5984 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1229 = icmp eq i32 %5984, 0, !dbg !260 + br i1 %.not.i1229, label %5987, label %5985, !dbg !260 + +5985: ; preds = %__nv_exp2f.exit1228 + %5986 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5940) #3, !dbg !260 + br label %__nv_exp2f.exit1231, !dbg !260 + +5987: ; preds = %__nv_exp2f.exit1228 + %5988 = tail call float @llvm.nvvm.ex2.approx.f(float %5940) #3, !dbg !260 + br label %__nv_exp2f.exit1231, !dbg !260 + +__nv_exp2f.exit1231: ; preds = %5985, %5987 + %.0.i1230 = phi float [ %5986, %5985 ], [ %5988, %5987 ], !dbg !260 + %5989 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1232 = icmp eq i32 %5989, 0, !dbg !260 + br i1 %.not.i1232, label %5992, label %5990, !dbg !260 + +5990: ; preds = %__nv_exp2f.exit1231 + %5991 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5941) #3, !dbg !260 + br label %__nv_exp2f.exit1234, !dbg !260 + +5992: ; preds = %__nv_exp2f.exit1231 + %5993 = tail call float @llvm.nvvm.ex2.approx.f(float %5941) #3, !dbg !260 + br label %__nv_exp2f.exit1234, !dbg !260 + +__nv_exp2f.exit1234: ; preds = %5990, %5992 + %.0.i1233 = phi float [ %5991, %5990 ], [ %5993, %5992 ], !dbg !260 + %5994 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, 
!dbg !260 + %.not.i1235 = icmp eq i32 %5994, 0, !dbg !260 + br i1 %.not.i1235, label %5997, label %5995, !dbg !260 + +5995: ; preds = %__nv_exp2f.exit1234 + %5996 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5942) #3, !dbg !260 + br label %__nv_exp2f.exit1237, !dbg !260 + +5997: ; preds = %__nv_exp2f.exit1234 + %5998 = tail call float @llvm.nvvm.ex2.approx.f(float %5942) #3, !dbg !260 + br label %__nv_exp2f.exit1237, !dbg !260 + +__nv_exp2f.exit1237: ; preds = %5995, %5997 + %.0.i1236 = phi float [ %5996, %5995 ], [ %5998, %5997 ], !dbg !260 + %5999 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1238 = icmp eq i32 %5999, 0, !dbg !260 + br i1 %.not.i1238, label %6002, label %6000, !dbg !260 + +6000: ; preds = %__nv_exp2f.exit1237 + %6001 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5943) #3, !dbg !260 + br label %__nv_exp2f.exit1240, !dbg !260 + +6002: ; preds = %__nv_exp2f.exit1237 + %6003 = tail call float @llvm.nvvm.ex2.approx.f(float %5943) #3, !dbg !260 + br label %__nv_exp2f.exit1240, !dbg !260 + +__nv_exp2f.exit1240: ; preds = %6000, %6002 + %.0.i1239 = phi float [ %6001, %6000 ], [ %6003, %6002 ], !dbg !260 + %6004 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1241 = icmp eq i32 %6004, 0, !dbg !260 + br i1 %.not.i1241, label %6007, label %6005, !dbg !260 + +6005: ; preds = %__nv_exp2f.exit1240 + %6006 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5944) #3, !dbg !260 + br label %__nv_exp2f.exit1243, !dbg !260 + +6007: ; preds = %__nv_exp2f.exit1240 + %6008 = tail call float @llvm.nvvm.ex2.approx.f(float %5944) #3, !dbg !260 + br label %__nv_exp2f.exit1243, !dbg !260 + +__nv_exp2f.exit1243: ; preds = %6005, %6007 + %.0.i1242 = phi float [ %6006, %6005 ], [ %6008, %6007 ], !dbg !260 + %6009 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1244 = icmp eq i32 %6009, 0, !dbg !260 + br i1 %.not.i1244, label %6012, label %6010, !dbg !260 + +6010: ; preds = 
%__nv_exp2f.exit1243 + %6011 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5945) #3, !dbg !260 + br label %__nv_exp2f.exit1246, !dbg !260 + +6012: ; preds = %__nv_exp2f.exit1243 + %6013 = tail call float @llvm.nvvm.ex2.approx.f(float %5945) #3, !dbg !260 + br label %__nv_exp2f.exit1246, !dbg !260 + +__nv_exp2f.exit1246: ; preds = %6010, %6012 + %.0.i1245 = phi float [ %6011, %6010 ], [ %6013, %6012 ], !dbg !260 + %6014 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1247 = icmp eq i32 %6014, 0, !dbg !260 + br i1 %.not.i1247, label %6017, label %6015, !dbg !260 + +6015: ; preds = %__nv_exp2f.exit1246 + %6016 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5946) #3, !dbg !260 + br label %__nv_exp2f.exit1249, !dbg !260 + +6017: ; preds = %__nv_exp2f.exit1246 + %6018 = tail call float @llvm.nvvm.ex2.approx.f(float %5946) #3, !dbg !260 + br label %__nv_exp2f.exit1249, !dbg !260 + +__nv_exp2f.exit1249: ; preds = %6015, %6017 + %.0.i1248 = phi float [ %6016, %6015 ], [ %6018, %6017 ], !dbg !260 + %6019 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1250 = icmp eq i32 %6019, 0, !dbg !260 + br i1 %.not.i1250, label %6022, label %6020, !dbg !260 + +6020: ; preds = %__nv_exp2f.exit1249 + %6021 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5947) #3, !dbg !260 + br label %__nv_exp2f.exit1252, !dbg !260 + +6022: ; preds = %__nv_exp2f.exit1249 + %6023 = tail call float @llvm.nvvm.ex2.approx.f(float %5947) #3, !dbg !260 + br label %__nv_exp2f.exit1252, !dbg !260 + +__nv_exp2f.exit1252: ; preds = %6020, %6022 + %.0.i1251 = phi float [ %6021, %6020 ], [ %6023, %6022 ], !dbg !260 + %6024 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1253 = icmp eq i32 %6024, 0, !dbg !260 + br i1 %.not.i1253, label %6027, label %6025, !dbg !260 + +6025: ; preds = %__nv_exp2f.exit1252 + %6026 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5948) #3, !dbg !260 + br label %__nv_exp2f.exit1255, !dbg 
!260 + +6027: ; preds = %__nv_exp2f.exit1252 + %6028 = tail call float @llvm.nvvm.ex2.approx.f(float %5948) #3, !dbg !260 + br label %__nv_exp2f.exit1255, !dbg !260 + +__nv_exp2f.exit1255: ; preds = %6025, %6027 + %.0.i1254 = phi float [ %6026, %6025 ], [ %6028, %6027 ], !dbg !260 + %6029 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1256 = icmp eq i32 %6029, 0, !dbg !260 + br i1 %.not.i1256, label %6032, label %6030, !dbg !260 + +6030: ; preds = %__nv_exp2f.exit1255 + %6031 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5949) #3, !dbg !260 + br label %__nv_exp2f.exit1258, !dbg !260 + +6032: ; preds = %__nv_exp2f.exit1255 + %6033 = tail call float @llvm.nvvm.ex2.approx.f(float %5949) #3, !dbg !260 + br label %__nv_exp2f.exit1258, !dbg !260 + +__nv_exp2f.exit1258: ; preds = %6030, %6032 + %.0.i1257 = phi float [ %6031, %6030 ], [ %6033, %6032 ], !dbg !260 + %6034 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1259 = icmp eq i32 %6034, 0, !dbg !260 + br i1 %.not.i1259, label %6037, label %6035, !dbg !260 + +6035: ; preds = %__nv_exp2f.exit1258 + %6036 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5950) #3, !dbg !260 + br label %__nv_exp2f.exit1261, !dbg !260 + +6037: ; preds = %__nv_exp2f.exit1258 + %6038 = tail call float @llvm.nvvm.ex2.approx.f(float %5950) #3, !dbg !260 + br label %__nv_exp2f.exit1261, !dbg !260 + +__nv_exp2f.exit1261: ; preds = %6035, %6037 + %.0.i1260 = phi float [ %6036, %6035 ], [ %6038, %6037 ], !dbg !260 + %6039 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1262 = icmp eq i32 %6039, 0, !dbg !260 + br i1 %.not.i1262, label %6042, label %6040, !dbg !260 + +6040: ; preds = %__nv_exp2f.exit1261 + %6041 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5951) #3, !dbg !260 + br label %__nv_exp2f.exit1264, !dbg !260 + +6042: ; preds = %__nv_exp2f.exit1261 + %6043 = tail call float @llvm.nvvm.ex2.approx.f(float %5951) #3, !dbg !260 + br label 
%__nv_exp2f.exit1264, !dbg !260 + +__nv_exp2f.exit1264: ; preds = %6040, %6042 + %.0.i1263 = phi float [ %6041, %6040 ], [ %6043, %6042 ], !dbg !260 + %6044 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1265 = icmp eq i32 %6044, 0, !dbg !260 + br i1 %.not.i1265, label %6047, label %6045, !dbg !260 + +6045: ; preds = %__nv_exp2f.exit1264 + %6046 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5952) #3, !dbg !260 + br label %__nv_exp2f.exit1267, !dbg !260 + +6047: ; preds = %__nv_exp2f.exit1264 + %6048 = tail call float @llvm.nvvm.ex2.approx.f(float %5952) #3, !dbg !260 + br label %__nv_exp2f.exit1267, !dbg !260 + +__nv_exp2f.exit1267: ; preds = %6045, %6047 + %.0.i1266 = phi float [ %6046, %6045 ], [ %6048, %6047 ], !dbg !260 + %6049 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1268 = icmp eq i32 %6049, 0, !dbg !260 + br i1 %.not.i1268, label %6052, label %6050, !dbg !260 + +6050: ; preds = %__nv_exp2f.exit1267 + %6051 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5953) #3, !dbg !260 + br label %__nv_exp2f.exit1270, !dbg !260 + +6052: ; preds = %__nv_exp2f.exit1267 + %6053 = tail call float @llvm.nvvm.ex2.approx.f(float %5953) #3, !dbg !260 + br label %__nv_exp2f.exit1270, !dbg !260 + +__nv_exp2f.exit1270: ; preds = %6050, %6052 + %.0.i1269 = phi float [ %6051, %6050 ], [ %6053, %6052 ], !dbg !260 + %6054 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1271 = icmp eq i32 %6054, 0, !dbg !260 + br i1 %.not.i1271, label %6057, label %6055, !dbg !260 + +6055: ; preds = %__nv_exp2f.exit1270 + %6056 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5954) #3, !dbg !260 + br label %__nv_exp2f.exit1273, !dbg !260 + +6057: ; preds = %__nv_exp2f.exit1270 + %6058 = tail call float @llvm.nvvm.ex2.approx.f(float %5954) #3, !dbg !260 + br label %__nv_exp2f.exit1273, !dbg !260 + +__nv_exp2f.exit1273: ; preds = %6055, %6057 + %.0.i1272 = phi float [ %6056, %6055 ], [ %6058, %6057 ], 
!dbg !260 + %6059 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1274 = icmp eq i32 %6059, 0, !dbg !260 + br i1 %.not.i1274, label %6062, label %6060, !dbg !260 + +6060: ; preds = %__nv_exp2f.exit1273 + %6061 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5955) #3, !dbg !260 + br label %__nv_exp2f.exit1276, !dbg !260 + +6062: ; preds = %__nv_exp2f.exit1273 + %6063 = tail call float @llvm.nvvm.ex2.approx.f(float %5955) #3, !dbg !260 + br label %__nv_exp2f.exit1276, !dbg !260 + +__nv_exp2f.exit1276: ; preds = %6060, %6062 + %.0.i1275 = phi float [ %6061, %6060 ], [ %6063, %6062 ], !dbg !260 + %6064 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1277 = icmp eq i32 %6064, 0, !dbg !260 + br i1 %.not.i1277, label %6067, label %6065, !dbg !260 + +6065: ; preds = %__nv_exp2f.exit1276 + %6066 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5956) #3, !dbg !260 + br label %__nv_exp2f.exit1279, !dbg !260 + +6067: ; preds = %__nv_exp2f.exit1276 + %6068 = tail call float @llvm.nvvm.ex2.approx.f(float %5956) #3, !dbg !260 + br label %__nv_exp2f.exit1279, !dbg !260 + +__nv_exp2f.exit1279: ; preds = %6065, %6067 + %.0.i1278 = phi float [ %6066, %6065 ], [ %6068, %6067 ], !dbg !260 + %6069 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1280 = icmp eq i32 %6069, 0, !dbg !260 + br i1 %.not.i1280, label %6072, label %6070, !dbg !260 + +6070: ; preds = %__nv_exp2f.exit1279 + %6071 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5957) #3, !dbg !260 + br label %__nv_exp2f.exit1282, !dbg !260 + +6072: ; preds = %__nv_exp2f.exit1279 + %6073 = tail call float @llvm.nvvm.ex2.approx.f(float %5957) #3, !dbg !260 + br label %__nv_exp2f.exit1282, !dbg !260 + +__nv_exp2f.exit1282: ; preds = %6070, %6072 + %.0.i1281 = phi float [ %6071, %6070 ], [ %6073, %6072 ], !dbg !260 + %6074 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1283 = icmp eq i32 %6074, 0, !dbg !260 + br i1 
%.not.i1283, label %6077, label %6075, !dbg !260 + +6075: ; preds = %__nv_exp2f.exit1282 + %6076 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5958) #3, !dbg !260 + br label %__nv_exp2f.exit1285, !dbg !260 + +6077: ; preds = %__nv_exp2f.exit1282 + %6078 = tail call float @llvm.nvvm.ex2.approx.f(float %5958) #3, !dbg !260 + br label %__nv_exp2f.exit1285, !dbg !260 + +__nv_exp2f.exit1285: ; preds = %6075, %6077 + %.0.i1284 = phi float [ %6076, %6075 ], [ %6078, %6077 ], !dbg !260 + %6079 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1286 = icmp eq i32 %6079, 0, !dbg !260 + br i1 %.not.i1286, label %6082, label %6080, !dbg !260 + +6080: ; preds = %__nv_exp2f.exit1285 + %6081 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5959) #3, !dbg !260 + br label %__nv_exp2f.exit1288, !dbg !260 + +6082: ; preds = %__nv_exp2f.exit1285 + %6083 = tail call float @llvm.nvvm.ex2.approx.f(float %5959) #3, !dbg !260 + br label %__nv_exp2f.exit1288, !dbg !260 + +__nv_exp2f.exit1288: ; preds = %6080, %6082 + %.0.i1287 = phi float [ %6081, %6080 ], [ %6083, %6082 ], !dbg !260 + %6084 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1289 = icmp eq i32 %6084, 0, !dbg !260 + br i1 %.not.i1289, label %6087, label %6085, !dbg !260 + +6085: ; preds = %__nv_exp2f.exit1288 + %6086 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5960) #3, !dbg !260 + br label %__nv_exp2f.exit1291, !dbg !260 + +6087: ; preds = %__nv_exp2f.exit1288 + %6088 = tail call float @llvm.nvvm.ex2.approx.f(float %5960) #3, !dbg !260 + br label %__nv_exp2f.exit1291, !dbg !260 + +__nv_exp2f.exit1291: ; preds = %6085, %6087 + %.0.i1290 = phi float [ %6086, %6085 ], [ %6088, %6087 ], !dbg !260 + %6089 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1292 = icmp eq i32 %6089, 0, !dbg !260 + br i1 %.not.i1292, label %6092, label %6090, !dbg !260 + +6090: ; preds = %__nv_exp2f.exit1291 + %6091 = tail call float 
@llvm.nvvm.ex2.approx.ftz.f(float %5961) #3, !dbg !260 + br label %__nv_exp2f.exit1294, !dbg !260 + +6092: ; preds = %__nv_exp2f.exit1291 + %6093 = tail call float @llvm.nvvm.ex2.approx.f(float %5961) #3, !dbg !260 + br label %__nv_exp2f.exit1294, !dbg !260 + +__nv_exp2f.exit1294: ; preds = %6090, %6092 + %.0.i1293 = phi float [ %6091, %6090 ], [ %6093, %6092 ], !dbg !260 + %6094 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1295 = icmp eq i32 %6094, 0, !dbg !260 + br i1 %.not.i1295, label %6097, label %6095, !dbg !260 + +6095: ; preds = %__nv_exp2f.exit1294 + %6096 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5962) #3, !dbg !260 + br label %__nv_exp2f.exit1297, !dbg !260 + +6097: ; preds = %__nv_exp2f.exit1294 + %6098 = tail call float @llvm.nvvm.ex2.approx.f(float %5962) #3, !dbg !260 + br label %__nv_exp2f.exit1297, !dbg !260 + +__nv_exp2f.exit1297: ; preds = %6095, %6097 + %.0.i1296 = phi float [ %6096, %6095 ], [ %6098, %6097 ], !dbg !260 + %6099 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1298 = icmp eq i32 %6099, 0, !dbg !260 + br i1 %.not.i1298, label %6102, label %6100, !dbg !260 + +6100: ; preds = %__nv_exp2f.exit1297 + %6101 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5963) #3, !dbg !260 + br label %__nv_exp2f.exit1300, !dbg !260 + +6102: ; preds = %__nv_exp2f.exit1297 + %6103 = tail call float @llvm.nvvm.ex2.approx.f(float %5963) #3, !dbg !260 + br label %__nv_exp2f.exit1300, !dbg !260 + +__nv_exp2f.exit1300: ; preds = %6100, %6102 + %.0.i1299 = phi float [ %6101, %6100 ], [ %6103, %6102 ], !dbg !260 + %6104 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1301 = icmp eq i32 %6104, 0, !dbg !260 + br i1 %.not.i1301, label %6107, label %6105, !dbg !260 + +6105: ; preds = %__nv_exp2f.exit1300 + %6106 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5964) #3, !dbg !260 + br label %__nv_exp2f.exit1303, !dbg !260 + +6107: ; preds = %__nv_exp2f.exit1300 + 
%6108 = tail call float @llvm.nvvm.ex2.approx.f(float %5964) #3, !dbg !260 + br label %__nv_exp2f.exit1303, !dbg !260 + +__nv_exp2f.exit1303: ; preds = %6105, %6107 + %.0.i1302 = phi float [ %6106, %6105 ], [ %6108, %6107 ], !dbg !260 + %6109 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1304 = icmp eq i32 %6109, 0, !dbg !260 + br i1 %.not.i1304, label %6112, label %6110, !dbg !260 + +6110: ; preds = %__nv_exp2f.exit1303 + %6111 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5965) #3, !dbg !260 + br label %__nv_exp2f.exit1306, !dbg !260 + +6112: ; preds = %__nv_exp2f.exit1303 + %6113 = tail call float @llvm.nvvm.ex2.approx.f(float %5965) #3, !dbg !260 + br label %__nv_exp2f.exit1306, !dbg !260 + +__nv_exp2f.exit1306: ; preds = %6110, %6112 + %.0.i1305 = phi float [ %6111, %6110 ], [ %6113, %6112 ], !dbg !260 + %6114 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1307 = icmp eq i32 %6114, 0, !dbg !260 + br i1 %.not.i1307, label %6117, label %6115, !dbg !260 + +6115: ; preds = %__nv_exp2f.exit1306 + %6116 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5966) #3, !dbg !260 + br label %__nv_exp2f.exit1309, !dbg !260 + +6117: ; preds = %__nv_exp2f.exit1306 + %6118 = tail call float @llvm.nvvm.ex2.approx.f(float %5966) #3, !dbg !260 + br label %__nv_exp2f.exit1309, !dbg !260 + +__nv_exp2f.exit1309: ; preds = %6115, %6117 + %.0.i1308 = phi float [ %6116, %6115 ], [ %6118, %6117 ], !dbg !260 + %6119 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1310 = icmp eq i32 %6119, 0, !dbg !260 + br i1 %.not.i1310, label %6122, label %6120, !dbg !260 + +6120: ; preds = %__nv_exp2f.exit1309 + %6121 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5967) #3, !dbg !260 + br label %__nv_exp2f.exit1312, !dbg !260 + +6122: ; preds = %__nv_exp2f.exit1309 + %6123 = tail call float @llvm.nvvm.ex2.approx.f(float %5967) #3, !dbg !260 + br label %__nv_exp2f.exit1312, !dbg !260 + +__nv_exp2f.exit1312: 
; preds = %6120, %6122 + %.0.i1311 = phi float [ %6121, %6120 ], [ %6123, %6122 ], !dbg !260 + %6124 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !260 + %.not.i1313 = icmp eq i32 %6124, 0, !dbg !260 + br i1 %.not.i1313, label %6127, label %6125, !dbg !260 + +6125: ; preds = %__nv_exp2f.exit1312 + %6126 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %5968) #3, !dbg !260 + br label %__nv_exp2f.exit1315, !dbg !260 + +6127: ; preds = %__nv_exp2f.exit1312 + %6128 = tail call float @llvm.nvvm.ex2.approx.f(float %5968) #3, !dbg !260 + br label %__nv_exp2f.exit1315, !dbg !260 + +__nv_exp2f.exit1315: ; preds = %6125, %6127 + %.0.i1314 = phi float [ %6126, %6125 ], [ %6128, %6127 ], !dbg !260 + %6129 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5288, !dbg !244 + %6130 = insertelement <2 x float> poison, float %.0.i1221, i64 0, !dbg !261 + %6131 = insertelement <2 x float> %6130, float %.0.i1224, i64 1, !dbg !261 + %6132 = fptrunc <2 x float> %6131 to <2 x bfloat>, !dbg !261 + %6133 = insertelement <2 x float> poison, float %.0.i1227, i64 0, !dbg !261 + %6134 = insertelement <2 x float> %6133, float %.0.i1230, i64 1, !dbg !261 + %6135 = fptrunc <2 x float> %6134 to <2 x bfloat>, !dbg !261 + %6136 = insertelement <2 x float> poison, float %.0.i1233, i64 0, !dbg !261 + %6137 = insertelement <2 x float> %6136, float %.0.i1236, i64 1, !dbg !261 + %6138 = fptrunc <2 x float> %6137 to <2 x bfloat>, !dbg !261 + %6139 = insertelement <2 x float> poison, float %.0.i1239, i64 0, !dbg !261 + %6140 = insertelement <2 x float> %6139, float %.0.i1242, i64 1, !dbg !261 + %6141 = fptrunc <2 x float> %6140 to <2 x bfloat>, !dbg !261 + %6142 = insertelement <2 x float> poison, float %.0.i1245, i64 0, !dbg !261 + %6143 = insertelement <2 x float> %6142, float %.0.i1248, i64 1, !dbg !261 + %6144 = fptrunc <2 x float> %6143 to <2 x bfloat>, !dbg !261 + %6145 = insertelement <2 x float> poison, float %.0.i1251, i64 
0, !dbg !261 + %6146 = insertelement <2 x float> %6145, float %.0.i1254, i64 1, !dbg !261 + %6147 = fptrunc <2 x float> %6146 to <2 x bfloat>, !dbg !261 + %6148 = insertelement <2 x float> poison, float %.0.i1257, i64 0, !dbg !261 + %6149 = insertelement <2 x float> %6148, float %.0.i1260, i64 1, !dbg !261 + %6150 = fptrunc <2 x float> %6149 to <2 x bfloat>, !dbg !261 + %6151 = insertelement <2 x float> poison, float %.0.i1263, i64 0, !dbg !261 + %6152 = insertelement <2 x float> %6151, float %.0.i1266, i64 1, !dbg !261 + %6153 = fptrunc <2 x float> %6152 to <2 x bfloat>, !dbg !261 + %6154 = insertelement <2 x float> poison, float %.0.i1269, i64 0, !dbg !261 + %6155 = insertelement <2 x float> %6154, float %.0.i1272, i64 1, !dbg !261 + %6156 = fptrunc <2 x float> %6155 to <2 x bfloat>, !dbg !261 + %6157 = insertelement <2 x float> poison, float %.0.i1275, i64 0, !dbg !261 + %6158 = insertelement <2 x float> %6157, float %.0.i1278, i64 1, !dbg !261 + %6159 = fptrunc <2 x float> %6158 to <2 x bfloat>, !dbg !261 + %6160 = insertelement <2 x float> poison, float %.0.i1281, i64 0, !dbg !261 + %6161 = insertelement <2 x float> %6160, float %.0.i1284, i64 1, !dbg !261 + %6162 = fptrunc <2 x float> %6161 to <2 x bfloat>, !dbg !261 + %6163 = insertelement <2 x float> poison, float %.0.i1287, i64 0, !dbg !261 + %6164 = insertelement <2 x float> %6163, float %.0.i1290, i64 1, !dbg !261 + %6165 = fptrunc <2 x float> %6164 to <2 x bfloat>, !dbg !261 + %6166 = insertelement <2 x float> poison, float %.0.i1293, i64 0, !dbg !261 + %6167 = insertelement <2 x float> %6166, float %.0.i1296, i64 1, !dbg !261 + %6168 = fptrunc <2 x float> %6167 to <2 x bfloat>, !dbg !261 + %6169 = insertelement <2 x float> poison, float %.0.i1299, i64 0, !dbg !261 + %6170 = insertelement <2 x float> %6169, float %.0.i1302, i64 1, !dbg !261 + %6171 = fptrunc <2 x float> %6170 to <2 x bfloat>, !dbg !261 + %6172 = insertelement <2 x float> poison, float %.0.i1305, i64 0, !dbg !261 + %6173 = insertelement 
<2 x float> %6172, float %.0.i1308, i64 1, !dbg !261 + %6174 = fptrunc <2 x float> %6173 to <2 x bfloat>, !dbg !261 + %6175 = insertelement <2 x float> poison, float %.0.i1311, i64 0, !dbg !261 + %6176 = insertelement <2 x float> %6175, float %.0.i1314, i64 1, !dbg !261 + %6177 = fptrunc <2 x float> %6176 to <2 x bfloat>, !dbg !261 + %6178 = bitcast <2 x bfloat> %6132 to i32, !dbg !262 + %6179 = bitcast <2 x bfloat> %6135 to i32, !dbg !262 + %6180 = bitcast <2 x bfloat> %6138 to i32, !dbg !262 + %6181 = bitcast <2 x bfloat> %6141 to i32, !dbg !262 + %6182 = bitcast <2 x bfloat> %6144 to i32, !dbg !262 + %6183 = bitcast <2 x bfloat> %6147 to i32, !dbg !262 + %6184 = bitcast <2 x bfloat> %6150 to i32, !dbg !262 + %6185 = bitcast <2 x bfloat> %6153 to i32, !dbg !262 + %6186 = bitcast <2 x bfloat> %6156 to i32, !dbg !262 + %6187 = bitcast <2 x bfloat> %6159 to i32, !dbg !262 + %6188 = bitcast <2 x bfloat> %6162 to i32, !dbg !262 + %6189 = bitcast <2 x bfloat> %6165 to i32, !dbg !262 + %6190 = bitcast <2 x bfloat> %6168 to i32, !dbg !262 + %6191 = bitcast <2 x bfloat> %6171 to i32, !dbg !262 + %6192 = bitcast <2 x bfloat> %6174 to i32, !dbg !262 + %6193 = bitcast <2 x bfloat> %6177 to i32, !dbg !262 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !262 + %6194 = ptrtoint ptr addrspace(3) %6129 to i32, !dbg !262 + %6195 = lshr exact i32 %6194, 4, !dbg !262 + %6196 = and i32 %6195, 16383, !dbg !262 + %6197 = zext nneg i32 %6196 to i64, !dbg !262 + %6198 = or disjoint i64 %6197, 4611686293338849280, !dbg !262 + %6199 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5150, float %5151, float %5152, float %5153, float %5154, float %5155, float %5156, float %5157, float %5158, float %5159, float %5160, float %5161, float %5162, float %5163, float %5164, float %5165, float %5166, float %5167, float %5168, float %5169, float %5170, float %5171, float %5172, float %5173, float %5174, float %5175, float %5176, float %5177, float %5178, float %5179, float %5180, float %5181, float %5182, float %5183, float %5184, float %5185, float %5186, float %5187, float %5188, float %5189, float %5190, float %5191, float %5192, float %5193, float %5194, float %5195, float %5196, float %5197, float %5198, float %5199, float %5200, float %5201, float %5202, float %5203, float %5204, float %5205, float %5206, float %5207, float %5208, float %5209, float %5210, float %5211, float %5212, float %5213, i32 %6178, i32 %6179, i32 %6180, i32 %6181, i64 %6198, i1 true) #3, !dbg !262 + %6200 = add i32 %6194, 2048, !dbg !262 + %6201 = lshr exact i32 %6200, 4, !dbg !262 + %6202 = and i32 %6201, 16383, !dbg !262 + %6203 = zext nneg i32 %6202 to i64, !dbg !262 + %6204 = or disjoint i64 %6203, 4611686293338849280, !dbg !262 + %6205 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 0, !dbg !262 + %6206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 1, !dbg !262 + %6207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 2, !dbg !262 + %6208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 3, !dbg !262 + %6209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 4, !dbg !262 + %6210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 5, !dbg !262 + %6211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 6, !dbg !262 + %6212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 7, !dbg !262 + %6213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 8, !dbg !262 + %6214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 9, !dbg !262 + %6215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 10, !dbg !262 + %6216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 11, !dbg !262 + %6217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 12, !dbg !262 + %6218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 13, !dbg !262 + %6219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 14, !dbg !262 + %6220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 15, !dbg !262 + %6221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 16, !dbg !262 + %6222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 17, !dbg !262 + %6223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 18, !dbg !262 + %6224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 19, !dbg !262 + %6225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 20, !dbg !262 + %6226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 21, !dbg !262 + %6227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 22, !dbg !262 + %6228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 23, !dbg !262 + %6229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 24, !dbg !262 + %6230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 25, !dbg !262 + %6231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 26, !dbg !262 + %6232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 27, !dbg !262 + %6233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 28, !dbg !262 + %6234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 29, !dbg !262 + %6235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 30, !dbg !262 + %6236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 31, !dbg !262 + %6237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 32, !dbg !262 + %6238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 33, !dbg !262 + %6239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 34, !dbg !262 + %6240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 35, !dbg !262 + %6241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 36, !dbg !262 + %6242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 37, !dbg !262 + %6243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 38, !dbg !262 + %6244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 39, !dbg !262 + %6245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 40, !dbg !262 + %6246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 41, !dbg !262 + %6247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 42, !dbg !262 + %6248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 43, !dbg !262 + %6249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 44, !dbg !262 + %6250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 45, !dbg !262 + %6251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 46, !dbg !262 + %6252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 47, !dbg !262 + %6253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 48, !dbg !262 + %6254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 49, !dbg !262 + %6255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 50, !dbg !262 + %6256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 51, !dbg !262 + %6257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 52, !dbg !262 + %6258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 53, !dbg !262 + %6259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 54, !dbg !262 + %6260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 55, !dbg !262 + %6261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 56, !dbg !262 + %6262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 57, !dbg !262 + %6263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 58, !dbg !262 + %6264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 59, !dbg !262 + %6265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 60, !dbg !262 + %6266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 61, !dbg !262 + %6267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 62, !dbg !262 + %6268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6199, 63, !dbg !262 + %6269 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %6205, float %6206, float %6207, float %6208, float %6209, float %6210, float %6211, float %6212, float %6213, float %6214, float %6215, float %6216, float %6217, float %6218, float %6219, float %6220, float %6221, float %6222, float %6223, float %6224, float %6225, float %6226, float %6227, float %6228, float %6229, float %6230, float %6231, float %6232, float %6233, float %6234, float %6235, float %6236, float %6237, float %6238, float %6239, float %6240, float %6241, float %6242, float %6243, float %6244, float %6245, float %6246, float %6247, float %6248, float %6249, float %6250, float %6251, float %6252, float %6253, float %6254, float %6255, float %6256, float %6257, float %6258, float %6259, float %6260, float %6261, float %6262, float %6263, float %6264, float %6265, float %6266, float %6267, float %6268, i32 %6182, i32 %6183, i32 %6184, i32 %6185, i64 %6204, i1 true) #3, !dbg !262 + %6270 = add i32 %6194, 4096, !dbg !262 + %6271 = lshr exact i32 %6270, 4, !dbg !262 + %6272 = and i32 %6271, 16383, !dbg !262 + %6273 = zext nneg i32 %6272 to i64, !dbg !262 + %6274 = or 
disjoint i64 %6273, 4611686293338849280, !dbg !262 + %6275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 0, !dbg !262 + %6276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 1, !dbg !262 + %6277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 2, !dbg !262 + %6278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %6269, 3, !dbg !262 + %6279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 4, !dbg !262 + %6280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 5, !dbg !262 + %6281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 6, !dbg !262 + %6282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 7, !dbg 
!262 + %6283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 8, !dbg !262 + %6284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 9, !dbg !262 + %6285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 10, !dbg !262 + %6286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 11, !dbg !262 + %6287 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 12, !dbg !262 + %6288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 13, !dbg !262 + %6289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 14, !dbg !262 + %6290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 15, !dbg !262 + %6291 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 16, !dbg !262 + %6292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 17, !dbg !262 + %6293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 18, !dbg !262 + %6294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 19, !dbg !262 + %6295 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 20, !dbg !262 + %6296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 21, !dbg !262 + %6297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 22, !dbg !262 + %6298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 23, !dbg !262 + %6299 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 24, !dbg !262 + %6300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 25, !dbg !262 + %6301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 26, !dbg !262 + %6302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 27, !dbg !262 + %6303 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 28, !dbg !262 + %6304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 29, !dbg !262 + %6305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 30, !dbg !262 + %6306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 31, !dbg !262 + %6307 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 32, !dbg !262 + %6308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 33, !dbg !262 + %6309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 34, !dbg !262 + %6310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 35, !dbg !262 + %6311 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 36, !dbg !262 + %6312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 37, !dbg !262 + %6313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 38, !dbg !262 + %6314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 39, !dbg !262 + %6315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 40, !dbg !262 + %6316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 41, !dbg !262 + %6317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 42, !dbg !262 + %6318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 43, !dbg !262 + %6319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 44, !dbg !262 + %6320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 45, !dbg !262 + %6321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 46, !dbg !262 + %6322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 47, !dbg !262 + %6323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 48, !dbg !262 + %6324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 49, !dbg !262 + %6325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 50, !dbg !262 + %6326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 51, !dbg !262 + %6327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 52, !dbg !262 + %6328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 53, !dbg !262 + %6329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 54, !dbg !262 + %6330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 55, !dbg !262 + %6331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 56, !dbg !262 + %6332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 57, !dbg !262 + %6333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 58, !dbg !262 + %6334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 59, !dbg !262 + %6335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 60, !dbg !262 + %6336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 61, !dbg !262 + %6337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 62, !dbg !262 + %6338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6269, 63, !dbg !262 + %6339 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %6275, float %6276, float %6277, float %6278, float %6279, float %6280, float %6281, float %6282, float %6283, float %6284, float %6285, float %6286, float %6287, float %6288, float %6289, float %6290, float %6291, float %6292, float %6293, float %6294, float %6295, float %6296, float %6297, float %6298, float %6299, float %6300, float %6301, float %6302, float %6303, float %6304, float %6305, float %6306, float %6307, float %6308, float %6309, float %6310, float %6311, float %6312, float %6313, float %6314, float %6315, float %6316, float %6317, float %6318, float %6319, float %6320, float %6321, float %6322, float %6323, float %6324, float %6325, float %6326, float %6327, float %6328, float %6329, float %6330, float %6331, float %6332, float %6333, float %6334, float %6335, float %6336, float %6337, float %6338, i32 %6186, i32 %6187, i32 %6188, i32 %6189, i64 %6274, i1 true) #3, !dbg !262 + %6340 = add i32 %6194, 6144, !dbg !262 + %6341 = lshr exact 
i32 %6340, 4, !dbg !262 + %6342 = and i32 %6341, 16383, !dbg !262 + %6343 = zext nneg i32 %6342 to i64, !dbg !262 + %6344 = or disjoint i64 %6343, 4611686293338849280, !dbg !262 + %6345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 0, !dbg !262 + %6346 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 1, !dbg !262 + %6347 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 2, !dbg !262 + %6348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 3, !dbg !262 + %6349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 4, !dbg !262 + %6350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 5, !dbg !262 + %6351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 6, !dbg !262 + %6352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 7, !dbg !262 + %6353 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 8, !dbg !262 + %6354 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 9, !dbg !262 + %6355 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 10, !dbg !262 + %6356 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 11, !dbg !262 + %6357 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 12, !dbg !262 + %6358 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 13, !dbg !262 + %6359 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 14, !dbg !262 + %6360 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %6339, 15, !dbg !262 + %6361 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 16, !dbg !262 + %6362 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 17, !dbg !262 + %6363 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 18, !dbg !262 + %6364 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %6339, 19, !dbg !262 + %6365 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 20, !dbg !262 + %6366 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 21, !dbg !262 + %6367 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 22, !dbg !262 + %6368 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %6339, 23, !dbg !262 + %6369 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 24, !dbg !262 + %6370 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 25, !dbg !262 + %6371 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 26, !dbg !262 + %6372 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %6339, 27, !dbg !262 + %6373 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 28, !dbg !262 + %6374 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 29, !dbg !262 + %6375 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 30, !dbg !262 + %6376 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %6339, 31, !dbg !262 + %6377 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 32, !dbg !262 + %6378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 33, !dbg !262 + %6379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 34, !dbg !262 + %6380 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float 
} %6339, 35, !dbg !262 + %6381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 36, !dbg !262 + %6382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 37, !dbg !262 + %6383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 38, !dbg !262 + %6384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 39, !dbg 
!262 + %6385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 40, !dbg !262 + %6386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 41, !dbg !262 + %6387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 42, !dbg !262 + %6388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 43, !dbg !262 + %6389 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 44, !dbg !262 + %6390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 45, !dbg !262 + %6391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 46, !dbg !262 + %6392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 47, !dbg !262 + %6393 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 48, !dbg !262 + %6394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 49, !dbg !262 + %6395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 50, !dbg !262 + %6396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 51, !dbg !262 + %6397 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 52, !dbg !262 + %6398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 53, !dbg !262 + %6399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 54, !dbg !262 + %6400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 55, !dbg !262 + %6401 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 56, !dbg !262 + %6402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 57, !dbg !262 + %6403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 58, !dbg !262 + %6404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 59, !dbg !262 + %6405 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 60, !dbg !262 + %6406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 61, !dbg !262 + %6407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 62, !dbg !262 + %6408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6339, 63, !dbg !262 + %6409 = tail call { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %6345, float %6346, float %6347, float %6348, float %6349, float %6350, float %6351, float %6352, float %6353, float %6354, float %6355, float %6356, float %6357, float %6358, float %6359, float %6360, float %6361, float %6362, float %6363, float %6364, float %6365, float %6366, float %6367, float %6368, float %6369, float %6370, float %6371, float %6372, float %6373, float %6374, float %6375, float %6376, float %6377, float %6378, float %6379, float %6380, float %6381, float %6382, float %6383, float %6384, float %6385, float %6386, float %6387, float %6388, float %6389, float %6390, float %6391, float %6392, float %6393, float %6394, float %6395, float %6396, float %6397, float %6398, float %6399, float %6400, float %6401, float %6402, float %6403, float %6404, float %6405, float %6406, float %6407, float %6408, i32 %6190, i32 %6191, i32 
%6192, i32 %6193, i64 %6344, i1 true) #3, !dbg !262 + %6410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 0, !dbg !262 + %6411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 1, !dbg !262 + %6412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 2, !dbg !262 + %6413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %6409, 3, !dbg !262 + %6414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 4, !dbg !262 + %6415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 5, !dbg !262 + %6416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 6, !dbg !262 + %6417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 7, !dbg 
!262 + %6418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 8, !dbg !262 + %6419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 9, !dbg !262 + %6420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 10, !dbg !262 + %6421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 11, !dbg !262 + %6422 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 12, !dbg !262 + %6423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 13, !dbg !262 + %6424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 14, !dbg !262 + %6425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 15, !dbg !262 + %6426 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 16, !dbg !262 + %6427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 17, !dbg !262 + %6428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 18, !dbg !262 + %6429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 19, !dbg !262 + %6430 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 20, !dbg !262 + %6431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 21, !dbg !262 + %6432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 22, !dbg !262 + %6433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 23, !dbg !262 + %6434 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 24, !dbg !262 + %6435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 25, !dbg !262 + %6436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 26, !dbg !262 + %6437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 27, !dbg !262 + %6438 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 28, !dbg !262 + %6439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 29, !dbg !262 + %6440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 30, !dbg !262 + %6441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 31, !dbg !262 + %6442 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 32, !dbg !262 + %6443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 33, !dbg !262 + %6444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 34, !dbg !262 + %6445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 35, !dbg !262 + %6446 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 36, !dbg !262 + %6447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 37, !dbg !262 + %6448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 38, !dbg !262 + %6449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 39, !dbg !262 + %6450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 40, !dbg !262 + %6451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 41, !dbg !262 + %6452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 42, !dbg !262 + %6453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 43, !dbg !262 + %6454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 44, !dbg !262 + %6455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 45, !dbg !262 + %6456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 46, !dbg !262 + %6457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 47, !dbg !262 + %6458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 48, !dbg !262 + %6459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 49, !dbg !262 + %6460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 50, !dbg !262 + %6461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 51, !dbg !262 + %6462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 52, !dbg !262 + %6463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 53, !dbg !262 + %6464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 54, !dbg !262 + %6465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 55, !dbg !262 + %6466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 56, !dbg !262 + %6467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 57, !dbg !262 + %6468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 58, !dbg !262 + %6469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 59, !dbg !262 + %6470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 60, !dbg !262 + %6471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 61, !dbg !262 + %6472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 62, !dbg !262 + %6473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6409, 63, !dbg !262 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !262 + %6474 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5290, !dbg 
!248 + %6475 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4812, !dbg !248 + %6476 = load float, ptr addrspace(3) %6475, align 8, !dbg !248 + %6477 = getelementptr inbounds nuw i8, ptr addrspace(3) %6475, i32 4, !dbg !248 + %6478 = load float, ptr addrspace(3) %6477, align 4, !dbg !248 + %6479 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4815, !dbg !248 + %6480 = load float, ptr addrspace(3) %6479, align 8, !dbg !248 + %6481 = getelementptr inbounds nuw i8, ptr addrspace(3) %6479, i32 4, !dbg !248 + %6482 = load float, ptr addrspace(3) %6481, align 4, !dbg !248 + %6483 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4817, !dbg !248 + %6484 = load float, ptr addrspace(3) %6483, align 8, !dbg !248 + %6485 = getelementptr inbounds nuw i8, ptr addrspace(3) %6483, i32 4, !dbg !248 + %6486 = load float, ptr addrspace(3) %6485, align 4, !dbg !248 + %6487 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4819, !dbg !248 + %6488 = load float, ptr addrspace(3) %6487, align 8, !dbg !248 + %6489 = getelementptr inbounds nuw i8, ptr addrspace(3) %6487, i32 4, !dbg !248 + %6490 = load float, ptr addrspace(3) %6489, align 4, !dbg !248 + %6491 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4821, !dbg !248 + %6492 = load float, ptr addrspace(3) %6491, align 8, !dbg !248 + %6493 = getelementptr inbounds nuw i8, ptr addrspace(3) %6491, i32 4, !dbg !248 + %6494 = load float, ptr addrspace(3) %6493, align 4, !dbg !248 + %6495 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4823, !dbg !248 + %6496 = load float, ptr addrspace(3) %6495, align 8, !dbg !248 + %6497 = getelementptr inbounds nuw i8, ptr addrspace(3) %6495, i32 4, !dbg !248 + %6498 = load float, ptr addrspace(3) %6497, align 4, !dbg !248 + %6499 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4825, !dbg !248 + %6500 = load float, ptr addrspace(3) %6499, align 8, !dbg !248 + %6501 = getelementptr inbounds nuw i8, ptr 
addrspace(3) %6499, i32 4, !dbg !248 + %6502 = load float, ptr addrspace(3) %6501, align 4, !dbg !248 + %6503 = getelementptr inbounds nuw i8, ptr addrspace(3) %6474, i32 %4827, !dbg !248 + %6504 = load float, ptr addrspace(3) %6503, align 8, !dbg !248 + %6505 = getelementptr inbounds nuw i8, ptr addrspace(3) %6503, i32 4, !dbg !248 + %6506 = load float, ptr addrspace(3) %6505, align 4, !dbg !248 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !263 + %6507 = add i32 %5358, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6508 = lshr exact i32 %6507, 4, !dbg !263 + %6509 = and i32 %6508, 16383, !dbg !263 + %6510 = zext nneg i32 %6509 to i64, !dbg !263 + %6511 = or disjoint i64 %6510, 4611686293372403712, !dbg !263 + %6512 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %6511, i64 %6198) #3, !dbg !263 + %6513 = add i32 %5370, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6514 = lshr exact i32 %6513, 4, !dbg !263 + %6515 = and i32 %6514, 16383, !dbg !263 + %6516 = zext nneg i32 %6515 to i64, !dbg !263 + %6517 = or disjoint i64 %6516, 4611686293372403712, !dbg !263 + %6518 = add i32 %6194, 32, !dbg !263 + %6519 = lshr exact i32 %6518, 4, !dbg !263 + %6520 = and i32 %6519, 16383, !dbg !263 + %6521 = zext nneg i32 %6520 to i64, !dbg !263 + %6522 = or disjoint i64 %6521, 4611686293338849280, !dbg !263 + %6523 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 0, !dbg !263 + %6524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 1, !dbg !263 + %6525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 2, !dbg !263 + %6526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 3, !dbg !263 + %6527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 4, !dbg !263 + %6528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 5, !dbg !263 + %6529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 6, !dbg !263 + %6530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 7, !dbg !263 + %6531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 8, !dbg !263 + %6532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 9, !dbg !263 + %6533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 10, !dbg !263 + %6534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 11, !dbg !263 + %6535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 12, !dbg !263 + %6536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 13, !dbg !263 + %6537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %6512, 14, !dbg !263 + %6538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 15, !dbg !263 + %6539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 16, !dbg !263 + %6540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 17, !dbg !263 + %6541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 18, !dbg !263 + %6542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 19, !dbg !263 + %6543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 20, !dbg !263 + %6544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 21, !dbg !263 + %6545 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 22, !dbg !263 + %6546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 23, !dbg !263 + %6547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 24, !dbg !263 + %6548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 25, !dbg !263 + %6549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 26, !dbg !263 + %6550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 27, !dbg !263 + %6551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 28, !dbg !263 + %6552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %6512, 29, !dbg !263 + %6553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 30, !dbg !263 + %6554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6512, 31, !dbg !263 + %6555 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6523, float %6524, float %6525, float %6526, float %6527, float %6528, float %6529, float %6530, float %6531, float %6532, float %6533, float %6534, float %6535, float %6536, float %6537, float %6538, float %6539, float %6540, float %6541, float %6542, float %6543, float %6544, float %6545, float %6546, float %6547, float %6548, float %6549, float %6550, float %6551, float %6552, float %6553, float %6554, i64 %6517, i64 %6522, i1 true) #3, !dbg !263 + %6556 = add i32 %5414, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6557 = lshr exact i32 %6556, 4, !dbg !263 + %6558 = and i32 %6557, 16383, !dbg !263 + %6559 = zext 
nneg i32 %6558 to i64, !dbg !263 + %6560 = or disjoint i64 %6559, 4611686293372403712, !dbg !263 + %6561 = add i32 %6194, 64, !dbg !263 + %6562 = lshr exact i32 %6561, 4, !dbg !263 + %6563 = and i32 %6562, 16383, !dbg !263 + %6564 = zext nneg i32 %6563 to i64, !dbg !263 + %6565 = or disjoint i64 %6564, 4611686293338849280, !dbg !263 + %6566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 0, !dbg !263 + %6567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 1, !dbg !263 + %6568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 2, !dbg !263 + %6569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 3, !dbg !263 + %6570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 4, !dbg !263 + %6571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 5, !dbg !263 + %6572 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 6, !dbg !263 + %6573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 7, !dbg !263 + %6574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 8, !dbg !263 + %6575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 9, !dbg !263 + %6576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 10, !dbg !263 + %6577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 11, !dbg !263 + %6578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 12, !dbg !263 + %6579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 13, !dbg !263 + %6580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 14, !dbg !263 + %6581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 15, !dbg !263 + %6582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 16, !dbg !263 + %6583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 17, !dbg !263 + %6584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 18, !dbg !263 + %6585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 19, !dbg !263 + %6586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %6555, 20, !dbg !263 + %6587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 21, !dbg !263 + %6588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 22, !dbg !263 + %6589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 23, !dbg !263 + %6590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 24, !dbg !263 + %6591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 25, !dbg !263 + %6592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 26, !dbg !263 + %6593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 27, !dbg !263 + %6594 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 28, !dbg !263 + %6595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 29, !dbg !263 + %6596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 30, !dbg !263 + %6597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6555, 31, !dbg !263 + %6598 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6566, float %6567, float %6568, float %6569, float %6570, float %6571, float %6572, float %6573, float %6574, float %6575, float %6576, float %6577, float %6578, float %6579, float %6580, float %6581, float %6582, float %6583, float %6584, float %6585, float %6586, float %6587, float %6588, float %6589, float %6590, float 
%6591, float %6592, float %6593, float %6594, float %6595, float %6596, float %6597, i64 %6560, i64 %6565, i1 true) #3, !dbg !263 + %6599 = add i32 %5458, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6600 = lshr exact i32 %6599, 4, !dbg !263 + %6601 = and i32 %6600, 16383, !dbg !263 + %6602 = zext nneg i32 %6601 to i64, !dbg !263 + %6603 = or disjoint i64 %6602, 4611686293372403712, !dbg !263 + %6604 = add i32 %6194, 96, !dbg !263 + %6605 = lshr exact i32 %6604, 4, !dbg !263 + %6606 = and i32 %6605, 16383, !dbg !263 + %6607 = zext nneg i32 %6606 to i64, !dbg !263 + %6608 = or disjoint i64 %6607, 4611686293338849280, !dbg !263 + %6609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 0, !dbg !263 + %6610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 1, !dbg !263 + %6611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 2, !dbg !263 + %6612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 3, !dbg !263 + %6613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %6598, 4, !dbg !263 + %6614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 5, !dbg !263 + %6615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 6, !dbg !263 + %6616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 7, !dbg !263 + %6617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 8, !dbg !263 + %6618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 9, !dbg !263 + %6619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 10, !dbg !263 + %6620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 11, !dbg !263 + %6621 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 12, !dbg !263 + %6622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 13, !dbg !263 + %6623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 14, !dbg !263 + %6624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 15, !dbg !263 + %6625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 16, !dbg !263 + %6626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 17, !dbg !263 + %6627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 18, !dbg !263 + %6628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 19, !dbg !263 + %6629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 20, !dbg !263 + %6630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 21, !dbg !263 + %6631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 22, !dbg !263 + %6632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 23, !dbg !263 + %6633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 24, !dbg !263 + %6634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 25, !dbg !263 + %6635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %6598, 26, !dbg !263 + %6636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 27, !dbg !263 + %6637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 28, !dbg !263 + %6638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 29, !dbg !263 + %6639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 30, !dbg !263 + %6640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6598, 31, !dbg !263 + %6641 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6609, float %6610, float %6611, float %6612, float %6613, float %6614, float %6615, float %6616, float %6617, float %6618, float %6619, float %6620, float %6621, float %6622, float %6623, float %6624, float %6625, float %6626, float %6627, float %6628, float %6629, float %6630, float %6631, float %6632, float %6633, float %6634, float %6635, float %6636, float %6637, float %6638, float %6639, float %6640, i64 %6603, i64 %6608, i1 true) #3, !dbg !263 + %6642 = add i32 %5502, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6643 = lshr exact i32 %6642, 4, !dbg !263 + %6644 = and i32 %6643, 16383, !dbg !263 + %6645 = zext nneg i32 %6644 to i64, !dbg !263 + %6646 = or disjoint i64 %6645, 4611686293372403712, !dbg !263 + %6647 = add i32 %6194, 8192, !dbg !263 + %6648 = lshr exact i32 %6647, 4, !dbg !263 + %6649 = and i32 %6648, 16383, !dbg !263 + %6650 = zext nneg i32 %6649 to i64, !dbg !263 + %6651 = or disjoint i64 %6650, 4611686293338849280, !dbg !263 + %6652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 0, !dbg !263 + %6653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 1, !dbg !263 + %6654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %6641, 2, !dbg !263 + %6655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 3, !dbg !263 + %6656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 4, !dbg !263 + %6657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 5, !dbg !263 + %6658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 6, !dbg !263 + %6659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 7, !dbg !263 + %6660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 8, !dbg !263 + %6661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 9, !dbg !263 + %6662 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 10, !dbg !263 + %6663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 11, !dbg !263 + %6664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 12, !dbg !263 + %6665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 13, !dbg !263 + %6666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 14, !dbg !263 + %6667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 15, !dbg !263 + %6668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 16, !dbg !263 + %6669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %6641, 17, !dbg !263 + %6670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 18, !dbg !263 + %6671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 19, !dbg !263 + %6672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 20, !dbg !263 + %6673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 21, !dbg !263 + %6674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 22, !dbg !263 + %6675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 23, !dbg !263 + %6676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 24, !dbg !263 + 
%6677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 25, !dbg !263 + %6678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 26, !dbg !263 + %6679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 27, !dbg !263 + %6680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 28, !dbg !263 + %6681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 29, !dbg !263 + %6682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 30, !dbg !263 + %6683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6641, 31, !dbg !263 + %6684 = tail call { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6652, float %6653, float %6654, float %6655, float %6656, float %6657, float %6658, float %6659, float %6660, float %6661, float %6662, float %6663, float %6664, float %6665, float %6666, float %6667, float %6668, float %6669, float %6670, float %6671, float %6672, float %6673, float %6674, float %6675, float %6676, float %6677, float %6678, float %6679, float %6680, float %6681, float %6682, float %6683, i64 %6646, i64 %6651, i1 true) #3, !dbg !263 + %6685 = add i32 %5546, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6686 = lshr exact i32 %6685, 4, !dbg !263 + %6687 = and i32 %6686, 16383, !dbg !263 + %6688 = zext nneg i32 %6687 to i64, !dbg !263 + %6689 = or disjoint i64 %6688, 4611686293372403712, !dbg !263 + %6690 = add i32 %6194, 8224, !dbg !263 + %6691 = lshr exact i32 %6690, 4, !dbg !263 + %6692 = and i32 %6691, 16383, !dbg !263 + %6693 = zext nneg i32 %6692 to i64, !dbg !263 + %6694 = or disjoint i64 %6693, 4611686293338849280, !dbg !263 + %6695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 0, !dbg !263 + %6696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 1, !dbg !263 + %6697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 2, !dbg !263 + %6698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 3, !dbg !263 + %6699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 4, !dbg !263 + %6700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 5, !dbg !263 + %6701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 6, !dbg !263 + %6702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 7, !dbg !263 + %6703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %6684, 8, !dbg !263 + %6704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 9, !dbg !263 + %6705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 10, !dbg !263 + %6706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 11, !dbg !263 + %6707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 12, !dbg !263 + %6708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 13, !dbg !263 + %6709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 14, !dbg !263 + %6710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 15, !dbg !263 + %6711 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 16, !dbg !263 + %6712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 17, !dbg !263 + %6713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 18, !dbg !263 + %6714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 19, !dbg !263 + %6715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 20, !dbg !263 + %6716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 21, !dbg !263 + %6717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 22, !dbg !263 + %6718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %6684, 23, !dbg !263 + %6719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 24, !dbg !263 + %6720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 25, !dbg !263 + %6721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 26, !dbg !263 + %6722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 27, !dbg !263 + %6723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 28, !dbg !263 + %6724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 29, !dbg !263 + %6725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 30, !dbg !263 + 
%6726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6684, 31, !dbg !263 + %6727 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6695, float %6696, float %6697, float %6698, float %6699, float %6700, float %6701, float %6702, float %6703, float %6704, float %6705, float %6706, float %6707, float %6708, float %6709, float %6710, float %6711, float %6712, float %6713, float %6714, float %6715, float %6716, float %6717, float %6718, float %6719, float %6720, float %6721, float %6722, float %6723, float %6724, float %6725, float %6726, i64 %6689, i64 %6694, i1 true) #3, !dbg !263 + %6728 = add i32 %5590, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6729 = lshr exact i32 %6728, 4, !dbg !263 + %6730 = and i32 %6729, 16383, !dbg !263 + %6731 = zext nneg i32 %6730 to i64, !dbg !263 + %6732 = or disjoint i64 %6731, 4611686293372403712, !dbg !263 + %6733 = add i32 %6194, 8256, !dbg !263 + %6734 = lshr exact i32 %6733, 4, !dbg !263 + %6735 = and i32 %6734, 16383, !dbg !263 + %6736 = zext nneg i32 %6735 to i64, !dbg !263 + %6737 = or disjoint i64 %6736, 4611686293338849280, !dbg !263 + %6738 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 0, !dbg !263 + %6739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 1, !dbg !263 + %6740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 2, !dbg !263 + %6741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 3, !dbg !263 + %6742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 4, !dbg !263 + %6743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 5, !dbg !263 + %6744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 6, !dbg !263 + %6745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 7, !dbg !263 + %6746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 8, !dbg !263 + %6747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 9, !dbg !263 + %6748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 10, !dbg !263 + %6749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 11, !dbg !263 + %6750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 12, !dbg !263 + %6751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 13, !dbg !263 + %6752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%6727, 14, !dbg !263 + %6753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 15, !dbg !263 + %6754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 16, !dbg !263 + %6755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 17, !dbg !263 + %6756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 18, !dbg !263 + %6757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 19, !dbg !263 + %6758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 20, !dbg !263 + %6759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 21, !dbg !263 + %6760 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 22, !dbg !263 + %6761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 23, !dbg !263 + %6762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 24, !dbg !263 + %6763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 25, !dbg !263 + %6764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 26, !dbg !263 + %6765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 27, !dbg !263 + %6766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 28, !dbg !263 + %6767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %6727, 29, !dbg !263 + %6768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 30, !dbg !263 + %6769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6727, 31, !dbg !263 + %6770 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6738, float %6739, float %6740, float %6741, float %6742, float %6743, float %6744, float %6745, float %6746, float %6747, float %6748, float %6749, float %6750, float %6751, float %6752, float %6753, float %6754, float %6755, float %6756, float %6757, float %6758, float %6759, float %6760, float %6761, float %6762, float %6763, float %6764, float %6765, float %6766, float %6767, float %6768, float %6769, i64 %6732, i64 %6737, i1 true) #3, !dbg !263 + %6771 = add i32 %5634, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !263 + %6772 = lshr exact i32 %6771, 4, !dbg !263 + %6773 = and i32 %6772, 16383, !dbg !263 + %6774 = zext nneg i32 
%6773 to i64, !dbg !263 + %6775 = or disjoint i64 %6774, 4611686293372403712, !dbg !263 + %6776 = add i32 %6194, 8288, !dbg !263 + %6777 = lshr exact i32 %6776, 4, !dbg !263 + %6778 = and i32 %6777, 16383, !dbg !263 + %6779 = zext nneg i32 %6778 to i64, !dbg !263 + %6780 = or disjoint i64 %6779, 4611686293338849280, !dbg !263 + %6781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 0, !dbg !263 + %6782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 1, !dbg !263 + %6783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 2, !dbg !263 + %6784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 3, !dbg !263 + %6785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 4, !dbg !263 + %6786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 5, !dbg !263 + %6787 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 6, !dbg !263 + %6788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 7, !dbg !263 + %6789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 8, !dbg !263 + %6790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 9, !dbg !263 + %6791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 10, !dbg !263 + %6792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 11, !dbg !263 + %6793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 12, !dbg !263 + %6794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 13, !dbg !263 + %6795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 14, !dbg !263 + %6796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 15, !dbg !263 + %6797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 16, !dbg !263 + %6798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 17, !dbg !263 + %6799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 18, !dbg !263 + %6800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 19, !dbg !263 + %6801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%6770, 20, !dbg !263 + %6802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 21, !dbg !263 + %6803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 22, !dbg !263 + %6804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 23, !dbg !263 + %6805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 24, !dbg !263 + %6806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 25, !dbg !263 + %6807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 26, !dbg !263 + %6808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 27, !dbg !263 + %6809 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 28, !dbg !263 + %6810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 29, !dbg !263 + %6811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 30, !dbg !263 + %6812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6770, 31, !dbg !263 + %6813 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6781, float %6782, float %6783, float %6784, float %6785, float %6786, float %6787, float %6788, float %6789, float %6790, float %6791, float %6792, float %6793, float %6794, float %6795, float %6796, float %6797, float %6798, float %6799, float %6800, float %6801, float %6802, float %6803, float %6804, float %6805, float %6806, 
float %6807, float %6808, float %6809, float %6810, float %6811, float %6812, i64 %6775, i64 %6780, i1 true) #3, !dbg !263 + %6814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 0, !dbg !263 + %6815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 1, !dbg !263 + %6816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 2, !dbg !263 + %6817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 3, !dbg !263 + %6818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 4, !dbg !263 + %6819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 5, !dbg !263 + %6820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 
6, !dbg !263 + %6821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 7, !dbg !263 + %6822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 8, !dbg !263 + %6823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 9, !dbg !263 + %6824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 10, !dbg !263 + %6825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 11, !dbg !263 + %6826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 12, !dbg !263 + %6827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 13, !dbg !263 + %6828 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 14, !dbg !263 + %6829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 15, !dbg !263 + %6830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 16, !dbg !263 + %6831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 17, !dbg !263 + %6832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 18, !dbg !263 + %6833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 19, !dbg !263 + %6834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 20, !dbg !263 + %6835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %6813, 21, !dbg !263 + %6836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 22, !dbg !263 + %6837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 23, !dbg !263 + %6838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 24, !dbg !263 + %6839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 25, !dbg !263 + %6840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 26, !dbg !263 + %6841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 27, !dbg !263 + %6842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 28, !dbg !263 + %6843 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 29, !dbg !263 + %6844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 30, !dbg !263 + %6845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6813, 31, !dbg !263 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !263 + %6846 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %6814, float %6815, float %6816, float %6817, float %6818, float %6819, float %6820, float %6821, float %6822, float %6823, float %6824, float %6825, float %6826, float %6827, float %6828, float %6829, float %6830, float %6831, float %6832, float %6833, float %6834, float %6835, float %6836, float %6837, float %6838, float %6839, float %6840, float %6841, float %6842, float 
%6843, float %6844, float %6845, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %6129, i32 0, i32 0) #3, !dbg !263 + %6847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 0, !dbg !263 + %6848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 1, !dbg !263 + %6849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 2, !dbg !263 + %6850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 3, !dbg !263 + %6851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 4, !dbg !263 + %6852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 5, !dbg !263 + %6853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 6, !dbg !263 + %6854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 7, !dbg !263 + %6855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 8, !dbg !263 + %6856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 9, !dbg !263 + %6857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 10, !dbg !263 + %6858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 11, !dbg !263 + %6859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 12, !dbg !263 + %6860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 13, !dbg !263 + %6861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 14, !dbg !263 + %6862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 15, !dbg !263 + %6863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 16, !dbg !263 + %6864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr 
addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 17, !dbg !263 + %6865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 18, !dbg !263 + %6866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 19, !dbg !263 + %6867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 20, !dbg !263 + %6868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 21, !dbg !263 + %6869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 22, !dbg !263 + %6870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr 
addrspace(3), i32, i32 } %6846, 23, !dbg !263 + %6871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 24, !dbg !263 + %6872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 25, !dbg !263 + %6873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 26, !dbg !263 + %6874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 27, !dbg !263 + %6875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 28, !dbg !263 + %6876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 29, 
!dbg !263 + %6877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 30, !dbg !263 + %6878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6846, 31, !dbg !263 + %6879 = fsub float %6847, %6476, !dbg !264 + %6880 = fsub float %6848, %6478, !dbg !264 + %6881 = fsub float %6849, %6476, !dbg !264 + %6882 = fsub float %6850, %6478, !dbg !264 + %6883 = fsub float %6851, %6480, !dbg !264 + %6884 = fsub float %6852, %6482, !dbg !264 + %6885 = fsub float %6853, %6480, !dbg !264 + %6886 = fsub float %6854, %6482, !dbg !264 + %6887 = fsub float %6855, %6484, !dbg !264 + %6888 = fsub float %6856, %6486, !dbg !264 + %6889 = fsub float %6857, %6484, !dbg !264 + %6890 = fsub float %6858, %6486, !dbg !264 + %6891 = fsub float %6859, %6488, !dbg !264 + %6892 = fsub float %6860, %6490, !dbg !264 + %6893 = fsub float %6861, %6488, !dbg !264 + %6894 = fsub float %6862, %6490, !dbg !264 + %6895 = fsub float %6863, %6492, !dbg !264 + %6896 = fsub float %6864, %6494, !dbg !264 + %6897 = fsub float %6865, %6492, !dbg !264 + %6898 = fsub float %6866, %6494, !dbg !264 + %6899 = fsub float %6867, %6496, !dbg !264 + %6900 = fsub float %6868, %6498, !dbg !264 + %6901 = fsub float %6869, %6496, !dbg !264 + %6902 = fsub float %6870, %6498, !dbg !264 + %6903 = fsub float %6871, %6500, !dbg !264 + %6904 = fsub float %6872, %6502, !dbg !264 + %6905 = fsub float %6873, %6500, !dbg !264 + %6906 = fsub float %6874, %6502, !dbg !264 + %6907 = fsub float %6875, %6504, !dbg !264 + %6908 = fsub float %6876, 
%6506, !dbg !264 + %6909 = fsub float %6877, %6504, !dbg !264 + %6910 = fsub float %6878, %6506, !dbg !264 + %6911 = fmul float %.0.i1221, %6879, !dbg !265 + %6912 = fmul float %.0.i1224, %6880, !dbg !265 + %6913 = fmul float %.0.i1227, %6881, !dbg !265 + %6914 = fmul float %.0.i1230, %6882, !dbg !265 + %6915 = fmul float %.0.i1233, %6883, !dbg !265 + %6916 = fmul float %.0.i1236, %6884, !dbg !265 + %6917 = fmul float %.0.i1239, %6885, !dbg !265 + %6918 = fmul float %.0.i1242, %6886, !dbg !265 + %6919 = fmul float %.0.i1245, %6887, !dbg !265 + %6920 = fmul float %.0.i1248, %6888, !dbg !265 + %6921 = fmul float %.0.i1251, %6889, !dbg !265 + %6922 = fmul float %.0.i1254, %6890, !dbg !265 + %6923 = fmul float %.0.i1257, %6891, !dbg !265 + %6924 = fmul float %.0.i1260, %6892, !dbg !265 + %6925 = fmul float %.0.i1263, %6893, !dbg !265 + %6926 = fmul float %.0.i1266, %6894, !dbg !265 + %6927 = fmul float %.0.i1269, %6895, !dbg !265 + %6928 = fmul float %.0.i1272, %6896, !dbg !265 + %6929 = fmul float %.0.i1275, %6897, !dbg !265 + %6930 = fmul float %.0.i1278, %6898, !dbg !265 + %6931 = fmul float %.0.i1281, %6899, !dbg !265 + %6932 = fmul float %.0.i1284, %6900, !dbg !265 + %6933 = fmul float %.0.i1287, %6901, !dbg !265 + %6934 = fmul float %.0.i1290, %6902, !dbg !265 + %6935 = fmul float %.0.i1293, %6903, !dbg !265 + %6936 = fmul float %.0.i1296, %6904, !dbg !265 + %6937 = fmul float %.0.i1299, %6905, !dbg !265 + %6938 = fmul float %.0.i1302, %6906, !dbg !265 + %6939 = fmul float %.0.i1305, %6907, !dbg !265 + %6940 = fmul float %.0.i1308, %6908, !dbg !265 + %6941 = fmul float %.0.i1311, %6909, !dbg !265 + %6942 = fmul float %.0.i1314, %6910, !dbg !265 + %6943 = fptrunc float %6911 to bfloat, !dbg !266 + %6944 = select i1 %5826, bfloat %6943, bfloat 0xR0000, !dbg !267 + %6945 = fptrunc float %6912 to bfloat, !dbg !266 + %6946 = select i1 %5828, bfloat %6945, bfloat 0xR0000, !dbg !267 + %6947 = fptrunc float %6913 to bfloat, !dbg !266 + %6948 = select i1 %5829, bfloat 
%6947, bfloat 0xR0000, !dbg !267 + %6949 = fptrunc float %6914 to bfloat, !dbg !266 + %6950 = select i1 %5830, bfloat %6949, bfloat 0xR0000, !dbg !267 + %6951 = fptrunc float %6915 to bfloat, !dbg !266 + %6952 = select i1 %5832, bfloat %6951, bfloat 0xR0000, !dbg !267 + %6953 = fptrunc float %6916 to bfloat, !dbg !266 + %6954 = select i1 %5834, bfloat %6953, bfloat 0xR0000, !dbg !267 + %6955 = fptrunc float %6917 to bfloat, !dbg !266 + %6956 = select i1 %5835, bfloat %6955, bfloat 0xR0000, !dbg !267 + %6957 = fptrunc float %6918 to bfloat, !dbg !266 + %6958 = select i1 %5836, bfloat %6957, bfloat 0xR0000, !dbg !267 + %6959 = fptrunc float %6919 to bfloat, !dbg !266 + %6960 = select i1 %5838, bfloat %6959, bfloat 0xR0000, !dbg !267 + %6961 = fptrunc float %6920 to bfloat, !dbg !266 + %6962 = select i1 %5840, bfloat %6961, bfloat 0xR0000, !dbg !267 + %6963 = fptrunc float %6921 to bfloat, !dbg !266 + %6964 = select i1 %5841, bfloat %6963, bfloat 0xR0000, !dbg !267 + %6965 = fptrunc float %6922 to bfloat, !dbg !266 + %6966 = select i1 %5842, bfloat %6965, bfloat 0xR0000, !dbg !267 + %6967 = fptrunc float %6923 to bfloat, !dbg !266 + %6968 = select i1 %5844, bfloat %6967, bfloat 0xR0000, !dbg !267 + %6969 = fptrunc float %6924 to bfloat, !dbg !266 + %6970 = select i1 %5846, bfloat %6969, bfloat 0xR0000, !dbg !267 + %6971 = fptrunc float %6925 to bfloat, !dbg !266 + %6972 = select i1 %5847, bfloat %6971, bfloat 0xR0000, !dbg !267 + %6973 = fptrunc float %6926 to bfloat, !dbg !266 + %6974 = select i1 %5848, bfloat %6973, bfloat 0xR0000, !dbg !267 + %6975 = fptrunc float %6927 to bfloat, !dbg !266 + %6976 = select i1 %5850, bfloat %6975, bfloat 0xR0000, !dbg !267 + %6977 = fptrunc float %6928 to bfloat, !dbg !266 + %6978 = select i1 %5852, bfloat %6977, bfloat 0xR0000, !dbg !267 + %6979 = fptrunc float %6929 to bfloat, !dbg !266 + %6980 = select i1 %5853, bfloat %6979, bfloat 0xR0000, !dbg !267 + %6981 = fptrunc float %6930 to bfloat, !dbg !266 + %6982 = select i1 %5854, 
bfloat %6981, bfloat 0xR0000, !dbg !267 + %6983 = fptrunc float %6931 to bfloat, !dbg !266 + %6984 = select i1 %5856, bfloat %6983, bfloat 0xR0000, !dbg !267 + %6985 = fptrunc float %6932 to bfloat, !dbg !266 + %6986 = select i1 %5858, bfloat %6985, bfloat 0xR0000, !dbg !267 + %6987 = fptrunc float %6933 to bfloat, !dbg !266 + %6988 = select i1 %5859, bfloat %6987, bfloat 0xR0000, !dbg !267 + %6989 = fptrunc float %6934 to bfloat, !dbg !266 + %6990 = select i1 %5860, bfloat %6989, bfloat 0xR0000, !dbg !267 + %6991 = fptrunc float %6935 to bfloat, !dbg !266 + %6992 = select i1 %5862, bfloat %6991, bfloat 0xR0000, !dbg !267 + %6993 = fptrunc float %6936 to bfloat, !dbg !266 + %6994 = select i1 %5864, bfloat %6993, bfloat 0xR0000, !dbg !267 + %6995 = fptrunc float %6937 to bfloat, !dbg !266 + %6996 = select i1 %5865, bfloat %6995, bfloat 0xR0000, !dbg !267 + %6997 = fptrunc float %6938 to bfloat, !dbg !266 + %6998 = select i1 %5866, bfloat %6997, bfloat 0xR0000, !dbg !267 + %6999 = fptrunc float %6939 to bfloat, !dbg !266 + %7000 = select i1 %5868, bfloat %6999, bfloat 0xR0000, !dbg !267 + %7001 = fptrunc float %6940 to bfloat, !dbg !266 + %7002 = select i1 %5870, bfloat %7001, bfloat 0xR0000, !dbg !267 + %7003 = fptrunc float %6941 to bfloat, !dbg !266 + %7004 = select i1 %5871, bfloat %7003, bfloat 0xR0000, !dbg !267 + %7005 = fptrunc float %6942 to bfloat, !dbg !266 + %7006 = select i1 %5872, bfloat %7005, bfloat 0xR0000, !dbg !267 + %7007 = insertelement <2 x bfloat> poison, bfloat %6944, i64 0, !dbg !268 + %7008 = insertelement <2 x bfloat> %7007, bfloat %6946, i64 1, !dbg !268 + %7009 = bitcast <2 x bfloat> %7008 to i32, !dbg !268 + %7010 = insertelement <2 x bfloat> poison, bfloat %6948, i64 0, !dbg !268 + %7011 = insertelement <2 x bfloat> %7010, bfloat %6950, i64 1, !dbg !268 + %7012 = bitcast <2 x bfloat> %7011 to i32, !dbg !268 + %7013 = insertelement <2 x bfloat> poison, bfloat %6952, i64 0, !dbg !268 + %7014 = insertelement <2 x bfloat> %7013, bfloat 
%6954, i64 1, !dbg !268 + %7015 = bitcast <2 x bfloat> %7014 to i32, !dbg !268 + %7016 = insertelement <2 x bfloat> poison, bfloat %6956, i64 0, !dbg !268 + %7017 = insertelement <2 x bfloat> %7016, bfloat %6958, i64 1, !dbg !268 + %7018 = bitcast <2 x bfloat> %7017 to i32, !dbg !268 + %7019 = insertelement <2 x bfloat> poison, bfloat %6960, i64 0, !dbg !268 + %7020 = insertelement <2 x bfloat> %7019, bfloat %6962, i64 1, !dbg !268 + %7021 = bitcast <2 x bfloat> %7020 to i32, !dbg !268 + %7022 = insertelement <2 x bfloat> poison, bfloat %6964, i64 0, !dbg !268 + %7023 = insertelement <2 x bfloat> %7022, bfloat %6966, i64 1, !dbg !268 + %7024 = bitcast <2 x bfloat> %7023 to i32, !dbg !268 + %7025 = insertelement <2 x bfloat> poison, bfloat %6968, i64 0, !dbg !268 + %7026 = insertelement <2 x bfloat> %7025, bfloat %6970, i64 1, !dbg !268 + %7027 = bitcast <2 x bfloat> %7026 to i32, !dbg !268 + %7028 = insertelement <2 x bfloat> poison, bfloat %6972, i64 0, !dbg !268 + %7029 = insertelement <2 x bfloat> %7028, bfloat %6974, i64 1, !dbg !268 + %7030 = bitcast <2 x bfloat> %7029 to i32, !dbg !268 + %7031 = insertelement <2 x bfloat> poison, bfloat %6976, i64 0, !dbg !268 + %7032 = insertelement <2 x bfloat> %7031, bfloat %6978, i64 1, !dbg !268 + %7033 = bitcast <2 x bfloat> %7032 to i32, !dbg !268 + %7034 = insertelement <2 x bfloat> poison, bfloat %6980, i64 0, !dbg !268 + %7035 = insertelement <2 x bfloat> %7034, bfloat %6982, i64 1, !dbg !268 + %7036 = bitcast <2 x bfloat> %7035 to i32, !dbg !268 + %7037 = insertelement <2 x bfloat> poison, bfloat %6984, i64 0, !dbg !268 + %7038 = insertelement <2 x bfloat> %7037, bfloat %6986, i64 1, !dbg !268 + %7039 = bitcast <2 x bfloat> %7038 to i32, !dbg !268 + %7040 = insertelement <2 x bfloat> poison, bfloat %6988, i64 0, !dbg !268 + %7041 = insertelement <2 x bfloat> %7040, bfloat %6990, i64 1, !dbg !268 + %7042 = bitcast <2 x bfloat> %7041 to i32, !dbg !268 + %7043 = insertelement <2 x bfloat> poison, bfloat %6992, i64 0, 
!dbg !268 + %7044 = insertelement <2 x bfloat> %7043, bfloat %6994, i64 1, !dbg !268 + %7045 = bitcast <2 x bfloat> %7044 to i32, !dbg !268 + %7046 = insertelement <2 x bfloat> poison, bfloat %6996, i64 0, !dbg !268 + %7047 = insertelement <2 x bfloat> %7046, bfloat %6998, i64 1, !dbg !268 + %7048 = bitcast <2 x bfloat> %7047 to i32, !dbg !268 + %7049 = insertelement <2 x bfloat> poison, bfloat %7000, i64 0, !dbg !268 + %7050 = insertelement <2 x bfloat> %7049, bfloat %7002, i64 1, !dbg !268 + %7051 = bitcast <2 x bfloat> %7050 to i32, !dbg !268 + %7052 = insertelement <2 x bfloat> poison, bfloat %7004, i64 0, !dbg !268 + %7053 = insertelement <2 x bfloat> %7052, bfloat %7006, i64 1, !dbg !268 + %7054 = bitcast <2 x bfloat> %7053 to i32, !dbg !268 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !268 + %7055 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5214, float %5215, float %5216, float %5217, float %5218, float %5219, float %5220, float %5221, float %5222, float %5223, float %5224, float %5225, float %5226, float %5227, float %5228, float %5229, float %5230, float %5231, float %5232, float %5233, float %5234, float %5235, float %5236, float %5237, float %5238, float %5239, float %5240, float %5241, float %5242, float %5243, float %5244, float %5245, float %5246, float %5247, float %5248, float %5249, float %5250, float %5251, float %5252, float %5253, float %5254, float %5255, float %5256, float %5257, float %5258, float %5259, float %5260, float %5261, float %5262, float %5263, float %5264, float %5265, float %5266, float %5267, float %5268, float %5269, float %5270, float %5271, float %5272, float %5273, float %5274, float %5275, float %5276, float %5277, i32 %7009, i32 %7012, i32 %7015, i32 %7018, i64 %5368, i1 true) #3, !dbg !268 + %7056 = add i32 %5364, 2048, !dbg !268 + %7057 = lshr exact i32 %7056, 4, !dbg !268 + %7058 = and i32 %7057, 16383, !dbg !268 + %7059 = zext nneg i32 %7058 to i64, !dbg !268 + %7060 = or disjoint i64 %7059, 4611686293338849280, !dbg !268 + %7061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %7055, 0, !dbg !268 + %7062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 1, !dbg !268 + %7063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 2, !dbg !268 + %7064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 3, !dbg !268 + %7065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 
4, !dbg !268 + %7066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 5, !dbg !268 + %7067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 6, !dbg !268 + %7068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 7, !dbg !268 + %7069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 8, !dbg !268 + %7070 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 9, !dbg !268 + %7071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 10, !dbg !268 + %7072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 11, !dbg !268 + %7073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 12, !dbg !268 + %7074 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 13, !dbg !268 + %7075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 14, !dbg !268 + %7076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 15, !dbg !268 + %7077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 16, !dbg !268 + %7078 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 17, !dbg !268 + %7079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 18, !dbg !268 + %7080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 19, !dbg !268 + %7081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 20, !dbg !268 + %7082 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 21, !dbg !268 + %7083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 22, !dbg !268 + %7084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 23, !dbg !268 + %7085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 24, !dbg !268 + %7086 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 25, !dbg !268 + %7087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 26, !dbg !268 + %7088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 27, !dbg !268 + %7089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 28, !dbg !268 + %7090 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 29, !dbg !268 + %7091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 30, !dbg !268 + %7092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 31, !dbg !268 + %7093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 32, !dbg !268 + %7094 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 33, !dbg !268 + %7095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 34, !dbg !268 + %7096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 35, !dbg !268 + %7097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 36, !dbg !268 + %7098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 37, !dbg !268 + %7099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 38, !dbg !268 + %7100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 39, !dbg !268 + %7101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 40, !dbg !268 + %7102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 41, !dbg !268 + %7103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 42, !dbg !268 + %7104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 43, !dbg !268 + %7105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 44, !dbg !268 + %7106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 45, !dbg !268 + %7107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 46, !dbg !268 + %7108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 47, !dbg !268 + %7109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 48, !dbg !268 + %7110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 49, !dbg !268 + %7111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 50, !dbg !268 + %7112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 51, !dbg !268 + %7113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 52, !dbg !268 + %7114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 53, !dbg !268 + %7115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 54, !dbg !268 + %7116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 55, !dbg !268 + %7117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 56, !dbg !268 + %7118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 57, !dbg !268 + %7119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 58, !dbg !268 + %7120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 59, !dbg !268 + %7121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 60, !dbg !268 + %7122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 61, !dbg !268 + %7123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 62, !dbg !268 + %7124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7055, 63, !dbg !268 + %7125 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7061, float %7062, float %7063, float %7064, float %7065, float %7066, float %7067, float %7068, float %7069, float %7070, float %7071, float %7072, float %7073, float %7074, float %7075, float %7076, float %7077, float %7078, float %7079, float %7080, float %7081, float %7082, float %7083, float %7084, float %7085, float %7086, float %7087, float %7088, float %7089, float %7090, float %7091, float %7092, float %7093, float %7094, float %7095, float %7096, float %7097, float %7098, float %7099, float %7100, float %7101, float %7102, float %7103, float %7104, float %7105, float %7106, float %7107, float %7108, float %7109, float %7110, float %7111, float %7112, float %7113, float %7114, float %7115, float %7116, float %7117, float %7118, float %7119, float %7120, float %7121, float %7122, float %7123, float %7124, i32 %7021, i32 %7024, i32 %7027, i32 %7030, i64 %7060, i1 true) #3, !dbg !268 + %7126 = add i32 %5364, 4096, !dbg !268 + %7127 = lshr exact i32 %7126, 4, !dbg !268 + %7128 = and i32 %7127, 16383, !dbg !268 + %7129 = zext nneg i32 %7128 to i64, !dbg !268 + %7130 = or disjoint i64 %7129, 4611686293338849280, !dbg !268 + %7131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 0, !dbg !268 + %7132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 1, !dbg !268 + %7133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 2, !dbg !268 + %7134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 3, !dbg !268 + %7135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 4, !dbg !268 + %7136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 5, !dbg !268 + %7137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 6, !dbg !268 + %7138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 7, !dbg !268 + %7139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 8, !dbg !268 + %7140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 9, !dbg !268 + %7141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 10, !dbg !268 + %7142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 11, !dbg !268 + %7143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 12, !dbg !268 + %7144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 13, !dbg !268 + %7145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 14, !dbg !268 + %7146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 15, !dbg !268 + %7147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 16, !dbg !268 + %7148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 17, !dbg !268 + %7149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 18, !dbg !268 + %7150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 19, !dbg !268 + %7151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 20, !dbg !268 + %7152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 21, !dbg !268 + %7153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 22, !dbg !268 + %7154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 23, !dbg !268 + %7155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 24, !dbg !268 + %7156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 25, !dbg !268 + %7157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 26, !dbg !268 + %7158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 27, !dbg !268 + %7159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 28, !dbg !268 + %7160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 29, !dbg !268 + %7161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 30, !dbg !268 + %7162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 31, !dbg !268 + %7163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 32, !dbg !268 + %7164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 33, !dbg !268 + %7165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 34, !dbg !268 + %7166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 35, !dbg !268 + %7167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 36, !dbg !268 + %7168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 37, !dbg !268 + %7169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 38, !dbg !268 + %7170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 39, !dbg !268 + %7171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 40, !dbg !268 + %7172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 41, !dbg !268 + %7173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 42, !dbg !268 + %7174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 43, !dbg !268 + %7175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 44, !dbg !268 + %7176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 45, !dbg !268 + %7177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 46, !dbg !268 + %7178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 47, !dbg !268 + %7179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 48, !dbg !268 + %7180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 49, !dbg !268 + %7181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 50, !dbg !268 + %7182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 51, !dbg !268 + %7183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 52, !dbg !268 + %7184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 53, !dbg !268 + %7185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 54, !dbg !268 + %7186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 55, !dbg !268 + %7187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 56, !dbg !268 + %7188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 57, !dbg !268 + %7189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 58, !dbg !268 + %7190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 59, !dbg !268 + %7191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %7125, 60, !dbg !268 + %7192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 61, !dbg !268 + %7193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 62, !dbg !268 + %7194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7125, 63, !dbg !268 + %7195 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7131, float %7132, float %7133, float %7134, float %7135, float %7136, float %7137, float %7138, float %7139, float %7140, float %7141, float %7142, float %7143, float %7144, float %7145, float %7146, float %7147, float %7148, float %7149, float %7150, float %7151, float %7152, float %7153, float %7154, float %7155, float %7156, float %7157, float %7158, float %7159, float %7160, float %7161, float %7162, float %7163, float %7164, float %7165, float %7166, float %7167, float %7168, float %7169, float %7170, float %7171, float %7172, float %7173, float %7174, float %7175, float %7176, float %7177, float %7178, float %7179, float %7180, float %7181, float %7182, float %7183, float %7184, float %7185, float %7186, float %7187, float %7188, float %7189, float %7190, float %7191, float %7192, float %7193, float %7194, i32 %7033, i32 %7036, i32 %7039, i32 %7042, i64 %7130, i1 true) #3, !dbg !268 + %7196 = add i32 %5364, 6144, !dbg !268 + %7197 = lshr exact i32 %7196, 4, !dbg !268 + %7198 = and i32 %7197, 16383, !dbg !268 + %7199 = zext nneg i32 %7198 to i64, !dbg !268 + %7200 = or disjoint i64 %7199, 4611686293338849280, !dbg !268 + %7201 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 0, !dbg !268 + %7202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 1, !dbg !268 + %7203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 2, !dbg !268 + %7204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 3, !dbg !268 + %7205 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 4, !dbg !268 + %7206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 5, !dbg !268 + %7207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 6, !dbg !268 + %7208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 7, !dbg !268 + %7209 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 8, !dbg !268 + %7210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 9, !dbg !268 + %7211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 10, !dbg !268 + %7212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 11, !dbg !268 + %7213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 12, !dbg !268 + %7214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 13, !dbg !268 + %7215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 14, !dbg !268 + %7216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 15, !dbg !268 + %7217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 16, !dbg !268 + %7218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 17, !dbg !268 + %7219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 18, !dbg !268 + %7220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 19, !dbg !268 + %7221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 20, !dbg !268 + %7222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 21, !dbg !268 + %7223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 22, !dbg !268 + %7224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 23, !dbg !268 + %7225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 24, !dbg !268 + %7226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 25, !dbg !268 + %7227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 26, !dbg !268 + %7228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 27, !dbg !268 + %7229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 28, !dbg !268 + %7230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 29, !dbg !268 + %7231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 30, !dbg !268 + %7232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 31, !dbg !268 + %7233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 32, !dbg !268 + %7234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 33, !dbg !268 + %7235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 34, !dbg !268 + %7236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 35, !dbg !268 + %7237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 36, !dbg !268 + %7238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 37, !dbg !268 + %7239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 38, !dbg !268 + %7240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 39, !dbg !268 + %7241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 40, !dbg !268 + %7242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 41, !dbg !268 + %7243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 42, !dbg !268 + %7244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 43, !dbg !268 + %7245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 44, !dbg !268 + %7246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 45, !dbg !268 + %7247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 46, !dbg !268 + %7248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 47, !dbg !268 + %7249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 48, !dbg !268 + %7250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 49, !dbg !268 + %7251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 50, !dbg !268 + %7252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 51, !dbg !268 + %7253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 52, !dbg !268 + %7254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 53, !dbg !268 + %7255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 54, !dbg !268 + %7256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 55, !dbg !268 + %7257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 56, !dbg !268 + %7258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 57, !dbg !268 + %7259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 58, !dbg !268 + %7260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 59, !dbg !268 + %7261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 60, !dbg !268 + %7262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 61, !dbg !268 + %7263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 62, !dbg !268 + %7264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7195, 63, !dbg !268 + %7265 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7201, float %7202, float %7203, float %7204, float %7205, float %7206, float %7207, float %7208, float %7209, float %7210, float %7211, float %7212, float %7213, float %7214, float %7215, float %7216, float %7217, float %7218, float %7219, float %7220, float %7221, float %7222, float %7223, float %7224, float %7225, float %7226, float %7227, float %7228, float %7229, float %7230, float %7231, float %7232, float %7233, float %7234, float %7235, float %7236, float %7237, float %7238, float %7239, float %7240, float %7241, float %7242, float %7243, float %7244, float %7245, float %7246, float %7247, float %7248, float %7249, float %7250, float %7251, float %7252, float %7253, float %7254, float %7255, float %7256, float %7257, float %7258, float %7259, float %7260, float %7261, float %7262, float %7263, float %7264, i32 %7045, i32 %7048, i32 %7051, i32 %7054, i64 %7200, i1 true) #3, !dbg !268 + %7266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 0, !dbg !268 + %7267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 1, !dbg !268 + %7268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 2, !dbg !268 + %7269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 3, !dbg !268 + %7270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 4, !dbg !268 + %7271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 5, !dbg !268 + %7272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 6, !dbg !268 + %7273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 7, !dbg !268 + %7274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 8, !dbg !268 + %7275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 9, !dbg !268 + %7276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 10, !dbg !268 + %7277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 11, !dbg !268 + %7278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 12, !dbg !268 + %7279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 13, !dbg !268 + %7280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 14, !dbg !268 + %7281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 15, !dbg !268 + %7282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 16, !dbg !268 + %7283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 17, !dbg !268 + %7284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 18, !dbg !268 + %7285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 19, !dbg !268 + %7286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 20, !dbg !268 + %7287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 21, !dbg !268 + %7288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 22, !dbg !268 + %7289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 23, !dbg !268 + %7290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 24, !dbg !268 + %7291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 25, !dbg !268 + %7292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 26, !dbg !268 + %7293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 27, !dbg !268 + %7294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 28, !dbg !268 + %7295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 29, !dbg !268 + %7296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 30, !dbg !268 + %7297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 31, !dbg !268 + %7298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 32, !dbg !268 + %7299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 33, !dbg !268 + %7300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 34, !dbg !268 + %7301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 35, !dbg !268 + %7302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 36, !dbg !268 + %7303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 37, !dbg !268 + %7304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 38, !dbg !268 + %7305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 39, !dbg !268 + %7306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 40, !dbg !268 + %7307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 41, !dbg !268 + %7308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 42, !dbg !268 + %7309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 43, !dbg !268 + %7310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 44, !dbg !268 + %7311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 45, !dbg !268 + %7312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 46, !dbg !268 + %7313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 47, !dbg !268 + %7314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 48, !dbg !268 + %7315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 49, !dbg !268 + %7316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 50, !dbg !268 + %7317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 51, !dbg !268 + %7318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 52, !dbg !268 + %7319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 53, !dbg !268 + %7320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 54, !dbg !268 + %7321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 55, !dbg !268 + %7322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 56, !dbg !268 + %7323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 57, !dbg !268 + %7324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 58, !dbg !268 + %7325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 59, !dbg !268 + %7326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 60, !dbg !268 + %7327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 61, !dbg !268 + %7328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 62, !dbg !268 + %7329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7265, 63, !dbg !268 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !268 + %7330 = insertelement <16 x i32> poison, i32 %5145, i64 0, !dbg !269 + %7331 = shufflevector <16 x i32> %7330, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !269 + %7332 = add <16 x i32> %5279, %7331, !dbg !269 + %7333 = add nuw nsw i32 %5278, 1, !dbg !213 + %7334 = lshr i32 %7333, 1, !dbg !270 + %7335 = zext nneg i32 
%7334 to i64, !dbg !271 + %7336 = getelementptr i32, ptr addrspace(1) %4707, i64 %7335, !dbg !271 + %7337 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !272 + %7338 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %7336, i64 %7337, i1 %5281) #3, !dbg !272 + %7339 = add nuw nsw i32 %7334, 1, !dbg !273 + %7340 = icmp slt i32 %7339, %4712, !dbg !274 + %7341 = getelementptr i8, ptr addrspace(1) %7336, i64 4, !dbg !275 + %7342 = and i1 %5281, %7340, !dbg !213 + %7343 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !276 + %7344 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %7341, i64 %7343, i1 %7342) #3, !dbg !276 + %7345 = and i32 %5278, 1, !dbg !277 + %7346 = sub i32 %7344, %7338, !dbg !278 + %7347 = shl i32 %7346, 7, !dbg !279 + %7348 = add i32 %7347, -64, !dbg !280 + %7349 = xor i32 %7345, 1, !dbg !281 + %7350 = mul nuw nsw i32 %7348, %7349, !dbg !281 + %7351 = shl nuw nsw i32 %7345, 6, !dbg !282 + %7352 = add i32 %7350, %7351, !dbg !283 + %7353 = shl i32 %7352, 12, !dbg !284 + %7354 = sext i32 %7353 to i64, !dbg !249 + %7355 = getelementptr bfloat, ptr addrspace(1) %.pn1201513, i64 %7354, !dbg !249 + %7356 = getelementptr bfloat, ptr addrspace(1) %.pn1041514, i64 %7354, !dbg !249 + %7357 = getelementptr bfloat, ptr addrspace(1) %.pn881515, i64 %7354, !dbg !249 + %7358 = getelementptr bfloat, ptr addrspace(1) %.pn721516, i64 %7354, !dbg !249 + %7359 = shl i32 %7352, 7, !dbg !285 + %7360 = sext i32 %7359 to i64, !dbg !250 + %7361 = getelementptr bfloat, ptr addrspace(1) %.pn1841517, i64 %7360, !dbg !250 + %7362 = getelementptr bfloat, ptr addrspace(1) %.pn1681518, i64 %7360, !dbg !250 + %7363 = 
getelementptr bfloat, ptr addrspace(1) %.pn1521519, i64 %7360, !dbg !250 + %7364 = getelementptr bfloat, ptr addrspace(1) %.pn1361520, i64 %7360, !dbg !250 + %7365 = add i32 %7352, %.pn2161521, !dbg !269 + %7366 = add i32 %7352, %.pn2121522, !dbg !269 + %7367 = add i32 %7352, %.pn2081523, !dbg !269 + %7368 = add i32 %7352, %.pn2041524, !dbg !269 + %7369 = add i32 %7352, %.pn2001525, !dbg !269 + %7370 = add i32 %7352, %.pn1961526, !dbg !269 + %7371 = add i32 %7352, %.pn1921527, !dbg !269 + %7372 = add i32 %7352, %.pn1881528, !dbg !269 + %7373 = add i32 %5147, 1, !dbg !213 + %7374 = icmp sgt i32 %7373, 1, !dbg !213 + %7375 = select i1 %7374, i32 0, i32 %7373, !dbg !213 + %7376 = add i32 %5149, 1, !dbg !213 + %7377 = icmp sgt i32 %7376, 2, !dbg !213 + %7378 = select i1 %7377, i32 0, i32 %7376, !dbg !213 + %7379 = shl i32 %7378, 13, !dbg !244 + %7380 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %7379, !dbg !244 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !244 + %7381 = getelementptr inbounds nuw i8, ptr addrspace(3) %7380, i32 %4793, !dbg !244 + %7382 = select i1 %5280, i32 16, i32 0, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %7381, ptr addrspace(1) %7355, i32 %7382) #3, !dbg !244 + %7383 = getelementptr inbounds nuw i8, ptr addrspace(3) %7380, i32 %4796, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7383, ptr addrspace(1) %7356, i32 %7382) #3, !dbg !244 + %7384 = getelementptr inbounds nuw i8, ptr addrspace(3) %7380, i32 %4798, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7384, ptr addrspace(1) %7357, i32 %7382) #3, !dbg !244 + %7385 = getelementptr inbounds nuw i8, ptr addrspace(3) %7380, i32 %4800, !dbg !244 + tail call void asm sideeffect 
"cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7385, ptr addrspace(1) %7358, i32 %7382) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %7386 = sext i32 %7365 to i64, !dbg !245 + %7387 = getelementptr float, ptr addrspace(1) %5087, i64 %7386, !dbg !245 + %7388 = sext i32 %7366 to i64, !dbg !245 + %7389 = getelementptr float, ptr addrspace(1) %5087, i64 %7388, !dbg !245 + %7390 = sext i32 %7367 to i64, !dbg !245 + %7391 = getelementptr float, ptr addrspace(1) %5087, i64 %7390, !dbg !245 + %7392 = sext i32 %7368 to i64, !dbg !245 + %7393 = getelementptr float, ptr addrspace(1) %5087, i64 %7392, !dbg !245 + %7394 = sext i32 %7369 to i64, !dbg !245 + %7395 = getelementptr float, ptr addrspace(1) %5087, i64 %7394, !dbg !245 + %7396 = sext i32 %7370 to i64, !dbg !245 + %7397 = getelementptr float, ptr addrspace(1) %5087, i64 %7396, !dbg !245 + %7398 = sext i32 %7371 to i64, !dbg !245 + %7399 = getelementptr float, ptr addrspace(1) %5087, i64 %7398, !dbg !245 + %7400 = sext i32 %7372 to i64, !dbg !245 + %7401 = getelementptr float, ptr addrspace(1) %5087, i64 %7400, !dbg !245 + %7402 = shl i32 %7375, 6, !dbg !246 + %7403 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %7402, !dbg !246 + %7404 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4812, !dbg !246 + %7405 = select i1 %5280, i32 8, i32 0, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %7404, ptr addrspace(1) %7387, i32 %7405, i1 %4811) #3, !dbg !246 + %7406 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4815, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7406, ptr addrspace(1) %7389, i32 %7405, i1 %4811) #3, !dbg !246 + %7407 = getelementptr inbounds nuw i8, ptr 
addrspace(3) %7403, i32 %4817, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7407, ptr addrspace(1) %7391, i32 %7405, i1 %4811) #3, !dbg !246 + %7408 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4819, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7408, ptr addrspace(1) %7393, i32 %7405, i1 %4811) #3, !dbg !246 + %7409 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4821, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7409, ptr addrspace(1) %7395, i32 %7405, i1 %4811) #3, !dbg !246 + %7410 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4823, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7410, ptr addrspace(1) %7397, i32 %7405, i1 %4811) #3, !dbg !246 + %7411 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4825, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7411, ptr addrspace(1) %7399, i32 %7405, i1 %4811) #3, !dbg !246 + %7412 = getelementptr inbounds nuw i8, ptr addrspace(3) %7403, i32 %4827, !dbg !246 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7412, ptr addrspace(1) %7401, i32 %7405, i1 %4811) #3, !dbg !246 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !246 + %7413 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %7379, !dbg !244 + %7414 = getelementptr inbounds nuw i8, ptr addrspace(3) %7413, i32 %4793, !dbg !244 + tail call void asm sideeffect 
"cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %7414, ptr addrspace(1) %7361, i32 %7382) #3, !dbg !244 + %7415 = getelementptr inbounds nuw i8, ptr addrspace(3) %7413, i32 %4796, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7415, ptr addrspace(1) %7362, i32 %7382) #3, !dbg !244 + %7416 = getelementptr inbounds nuw i8, ptr addrspace(3) %7413, i32 %4798, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7416, ptr addrspace(1) %7363, i32 %7382) #3, !dbg !244 + %7417 = getelementptr inbounds nuw i8, ptr addrspace(3) %7413, i32 %4800, !dbg !244 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %7417, ptr addrspace(1) %7364, i32 %7382) #3, !dbg !244 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !244 + %7418 = getelementptr float, ptr addrspace(1) %5088, i64 %7386, !dbg !247 + %7419 = getelementptr float, ptr addrspace(1) %5088, i64 %7388, !dbg !247 + %7420 = getelementptr float, ptr addrspace(1) %5088, i64 %7390, !dbg !247 + %7421 = getelementptr float, ptr addrspace(1) %5088, i64 %7392, !dbg !247 + %7422 = getelementptr float, ptr addrspace(1) %5088, i64 %7394, !dbg !247 + %7423 = getelementptr float, ptr addrspace(1) %5088, i64 %7396, !dbg !247 + %7424 = getelementptr float, ptr addrspace(1) %5088, i64 %7398, !dbg !247 + %7425 = getelementptr float, ptr addrspace(1) %5088, i64 %7400, !dbg !247 + %7426 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %7402, !dbg !248 + %7427 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4812, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %7427, ptr addrspace(1) %7418, 
i32 %7405, i1 %4811) #3, !dbg !248 + %7428 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4815, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7428, ptr addrspace(1) %7419, i32 %7405, i1 %4811) #3, !dbg !248 + %7429 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4817, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7429, ptr addrspace(1) %7420, i32 %7405, i1 %4811) #3, !dbg !248 + %7430 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4819, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7430, ptr addrspace(1) %7421, i32 %7405, i1 %4811) #3, !dbg !248 + %7431 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4821, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7431, ptr addrspace(1) %7422, i32 %7405, i1 %4811) #3, !dbg !248 + %7432 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4823, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7432, ptr addrspace(1) %7423, i32 %7405, i1 %4811) #3, !dbg !248 + %7433 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4825, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7433, ptr addrspace(1) %7424, i32 %7405, i1 %4811) #3, !dbg !248 + %7434 = getelementptr inbounds nuw i8, ptr addrspace(3) %7426, i32 %4827, !dbg !248 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %7434, ptr addrspace(1) %7425, i32 
%7405, i1 %4811) #3, !dbg !248 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !248 + %exitcond.not = icmp eq i32 %7333, %smax, !dbg !213 + br i1 %exitcond.not, label %._crit_edge, label %.lr.ph, !dbg !213 + +._crit_edge: ; preds = %__nv_exp2f.exit1315, %4945 + %7435 = phi float [ %4946, %4945 ], [ %7266, %__nv_exp2f.exit1315 ] + %7436 = phi float [ %4947, %4945 ], [ %7267, %__nv_exp2f.exit1315 ] + %7437 = phi float [ %4948, %4945 ], [ %7268, %__nv_exp2f.exit1315 ] + %7438 = phi float [ %4949, %4945 ], [ %7269, %__nv_exp2f.exit1315 ] + %7439 = phi float [ %4950, %4945 ], [ %7270, %__nv_exp2f.exit1315 ] + %7440 = phi float [ %4951, %4945 ], [ %7271, %__nv_exp2f.exit1315 ] + %7441 = phi float [ %4952, %4945 ], [ %7272, %__nv_exp2f.exit1315 ] + %7442 = phi float [ %4953, %4945 ], [ %7273, %__nv_exp2f.exit1315 ] + %7443 = phi float [ %4954, %4945 ], [ %7274, %__nv_exp2f.exit1315 ] + %7444 = phi float [ %4955, %4945 ], [ %7275, %__nv_exp2f.exit1315 ] + %7445 = phi float [ %4956, %4945 ], [ %7276, %__nv_exp2f.exit1315 ] + %7446 = phi float [ %4957, %4945 ], [ %7277, %__nv_exp2f.exit1315 ] + %7447 = phi float [ %4958, %4945 ], [ %7278, %__nv_exp2f.exit1315 ] + %7448 = phi float [ %4959, %4945 ], [ %7279, %__nv_exp2f.exit1315 ] + %7449 = phi float [ %4960, %4945 ], [ %7280, %__nv_exp2f.exit1315 ] + %7450 = phi float [ %4961, %4945 ], [ %7281, %__nv_exp2f.exit1315 ] + %7451 = phi float [ %4962, %4945 ], [ %7282, %__nv_exp2f.exit1315 ] + %7452 = phi float [ %4963, %4945 ], [ %7283, %__nv_exp2f.exit1315 ] + %7453 = phi float [ %4964, %4945 ], [ %7284, %__nv_exp2f.exit1315 ] + %7454 = phi float [ %4965, %4945 ], [ %7285, %__nv_exp2f.exit1315 ] + %7455 = phi float [ %4966, %4945 ], [ %7286, %__nv_exp2f.exit1315 ] + %7456 = phi float [ %4967, %4945 ], [ %7287, %__nv_exp2f.exit1315 ] + %7457 = phi float [ %4968, %4945 ], [ %7288, %__nv_exp2f.exit1315 ] + %7458 = phi float [ %4969, %4945 ], [ %7289, %__nv_exp2f.exit1315 ] + %7459 = phi float [ %4970, %4945 ], [ %7290, 
%__nv_exp2f.exit1315 ] + %7460 = phi float [ %4971, %4945 ], [ %7291, %__nv_exp2f.exit1315 ] + %7461 = phi float [ %4972, %4945 ], [ %7292, %__nv_exp2f.exit1315 ] + %7462 = phi float [ %4973, %4945 ], [ %7293, %__nv_exp2f.exit1315 ] + %7463 = phi float [ %4974, %4945 ], [ %7294, %__nv_exp2f.exit1315 ] + %7464 = phi float [ %4975, %4945 ], [ %7295, %__nv_exp2f.exit1315 ] + %7465 = phi float [ %4976, %4945 ], [ %7296, %__nv_exp2f.exit1315 ] + %7466 = phi float [ %4977, %4945 ], [ %7297, %__nv_exp2f.exit1315 ] + %7467 = phi float [ %4978, %4945 ], [ %7298, %__nv_exp2f.exit1315 ] + %7468 = phi float [ %4979, %4945 ], [ %7299, %__nv_exp2f.exit1315 ] + %7469 = phi float [ %4980, %4945 ], [ %7300, %__nv_exp2f.exit1315 ] + %7470 = phi float [ %4981, %4945 ], [ %7301, %__nv_exp2f.exit1315 ] + %7471 = phi float [ %4982, %4945 ], [ %7302, %__nv_exp2f.exit1315 ] + %7472 = phi float [ %4983, %4945 ], [ %7303, %__nv_exp2f.exit1315 ] + %7473 = phi float [ %4984, %4945 ], [ %7304, %__nv_exp2f.exit1315 ] + %7474 = phi float [ %4985, %4945 ], [ %7305, %__nv_exp2f.exit1315 ] + %7475 = phi float [ %4986, %4945 ], [ %7306, %__nv_exp2f.exit1315 ] + %7476 = phi float [ %4987, %4945 ], [ %7307, %__nv_exp2f.exit1315 ] + %7477 = phi float [ %4988, %4945 ], [ %7308, %__nv_exp2f.exit1315 ] + %7478 = phi float [ %4989, %4945 ], [ %7309, %__nv_exp2f.exit1315 ] + %7479 = phi float [ %4990, %4945 ], [ %7310, %__nv_exp2f.exit1315 ] + %7480 = phi float [ %4991, %4945 ], [ %7311, %__nv_exp2f.exit1315 ] + %7481 = phi float [ %4992, %4945 ], [ %7312, %__nv_exp2f.exit1315 ] + %7482 = phi float [ %4993, %4945 ], [ %7313, %__nv_exp2f.exit1315 ] + %7483 = phi float [ %4994, %4945 ], [ %7314, %__nv_exp2f.exit1315 ] + %7484 = phi float [ %4995, %4945 ], [ %7315, %__nv_exp2f.exit1315 ] + %7485 = phi float [ %4996, %4945 ], [ %7316, %__nv_exp2f.exit1315 ] + %7486 = phi float [ %4997, %4945 ], [ %7317, %__nv_exp2f.exit1315 ] + %7487 = phi float [ %4998, %4945 ], [ %7318, %__nv_exp2f.exit1315 ] + %7488 = phi 
float [ %4999, %4945 ], [ %7319, %__nv_exp2f.exit1315 ] + %7489 = phi float [ %5000, %4945 ], [ %7320, %__nv_exp2f.exit1315 ] + %7490 = phi float [ %5001, %4945 ], [ %7321, %__nv_exp2f.exit1315 ] + %7491 = phi float [ %5002, %4945 ], [ %7322, %__nv_exp2f.exit1315 ] + %7492 = phi float [ %5003, %4945 ], [ %7323, %__nv_exp2f.exit1315 ] + %7493 = phi float [ %5004, %4945 ], [ %7324, %__nv_exp2f.exit1315 ] + %7494 = phi float [ %5005, %4945 ], [ %7325, %__nv_exp2f.exit1315 ] + %7495 = phi float [ %5006, %4945 ], [ %7326, %__nv_exp2f.exit1315 ] + %7496 = phi float [ %5007, %4945 ], [ %7327, %__nv_exp2f.exit1315 ] + %7497 = phi float [ %5008, %4945 ], [ %7328, %__nv_exp2f.exit1315 ] + %7498 = phi float [ %5009, %4945 ], [ %7329, %__nv_exp2f.exit1315 ] + %7499 = phi float [ %5010, %4945 ], [ %6410, %__nv_exp2f.exit1315 ] + %7500 = phi float [ %5011, %4945 ], [ %6411, %__nv_exp2f.exit1315 ] + %7501 = phi float [ %5012, %4945 ], [ %6412, %__nv_exp2f.exit1315 ] + %7502 = phi float [ %5013, %4945 ], [ %6413, %__nv_exp2f.exit1315 ] + %7503 = phi float [ %5014, %4945 ], [ %6414, %__nv_exp2f.exit1315 ] + %7504 = phi float [ %5015, %4945 ], [ %6415, %__nv_exp2f.exit1315 ] + %7505 = phi float [ %5016, %4945 ], [ %6416, %__nv_exp2f.exit1315 ] + %7506 = phi float [ %5017, %4945 ], [ %6417, %__nv_exp2f.exit1315 ] + %7507 = phi float [ %5018, %4945 ], [ %6418, %__nv_exp2f.exit1315 ] + %7508 = phi float [ %5019, %4945 ], [ %6419, %__nv_exp2f.exit1315 ] + %7509 = phi float [ %5020, %4945 ], [ %6420, %__nv_exp2f.exit1315 ] + %7510 = phi float [ %5021, %4945 ], [ %6421, %__nv_exp2f.exit1315 ] + %7511 = phi float [ %5022, %4945 ], [ %6422, %__nv_exp2f.exit1315 ] + %7512 = phi float [ %5023, %4945 ], [ %6423, %__nv_exp2f.exit1315 ] + %7513 = phi float [ %5024, %4945 ], [ %6424, %__nv_exp2f.exit1315 ] + %7514 = phi float [ %5025, %4945 ], [ %6425, %__nv_exp2f.exit1315 ] + %7515 = phi float [ %5026, %4945 ], [ %6426, %__nv_exp2f.exit1315 ] + %7516 = phi float [ %5027, %4945 ], [ %6427, 
%__nv_exp2f.exit1315 ] + %7517 = phi float [ %5028, %4945 ], [ %6428, %__nv_exp2f.exit1315 ] + %7518 = phi float [ %5029, %4945 ], [ %6429, %__nv_exp2f.exit1315 ] + %7519 = phi float [ %5030, %4945 ], [ %6430, %__nv_exp2f.exit1315 ] + %7520 = phi float [ %5031, %4945 ], [ %6431, %__nv_exp2f.exit1315 ] + %7521 = phi float [ %5032, %4945 ], [ %6432, %__nv_exp2f.exit1315 ] + %7522 = phi float [ %5033, %4945 ], [ %6433, %__nv_exp2f.exit1315 ] + %7523 = phi float [ %5034, %4945 ], [ %6434, %__nv_exp2f.exit1315 ] + %7524 = phi float [ %5035, %4945 ], [ %6435, %__nv_exp2f.exit1315 ] + %7525 = phi float [ %5036, %4945 ], [ %6436, %__nv_exp2f.exit1315 ] + %7526 = phi float [ %5037, %4945 ], [ %6437, %__nv_exp2f.exit1315 ] + %7527 = phi float [ %5038, %4945 ], [ %6438, %__nv_exp2f.exit1315 ] + %7528 = phi float [ %5039, %4945 ], [ %6439, %__nv_exp2f.exit1315 ] + %7529 = phi float [ %5040, %4945 ], [ %6440, %__nv_exp2f.exit1315 ] + %7530 = phi float [ %5041, %4945 ], [ %6441, %__nv_exp2f.exit1315 ] + %7531 = phi float [ %5042, %4945 ], [ %6442, %__nv_exp2f.exit1315 ] + %7532 = phi float [ %5043, %4945 ], [ %6443, %__nv_exp2f.exit1315 ] + %7533 = phi float [ %5044, %4945 ], [ %6444, %__nv_exp2f.exit1315 ] + %7534 = phi float [ %5045, %4945 ], [ %6445, %__nv_exp2f.exit1315 ] + %7535 = phi float [ %5046, %4945 ], [ %6446, %__nv_exp2f.exit1315 ] + %7536 = phi float [ %5047, %4945 ], [ %6447, %__nv_exp2f.exit1315 ] + %7537 = phi float [ %5048, %4945 ], [ %6448, %__nv_exp2f.exit1315 ] + %7538 = phi float [ %5049, %4945 ], [ %6449, %__nv_exp2f.exit1315 ] + %7539 = phi float [ %5050, %4945 ], [ %6450, %__nv_exp2f.exit1315 ] + %7540 = phi float [ %5051, %4945 ], [ %6451, %__nv_exp2f.exit1315 ] + %7541 = phi float [ %5052, %4945 ], [ %6452, %__nv_exp2f.exit1315 ] + %7542 = phi float [ %5053, %4945 ], [ %6453, %__nv_exp2f.exit1315 ] + %7543 = phi float [ %5054, %4945 ], [ %6454, %__nv_exp2f.exit1315 ] + %7544 = phi float [ %5055, %4945 ], [ %6455, %__nv_exp2f.exit1315 ] + %7545 = phi 
float [ %5056, %4945 ], [ %6456, %__nv_exp2f.exit1315 ] + %7546 = phi float [ %5057, %4945 ], [ %6457, %__nv_exp2f.exit1315 ] + %7547 = phi float [ %5058, %4945 ], [ %6458, %__nv_exp2f.exit1315 ] + %7548 = phi float [ %5059, %4945 ], [ %6459, %__nv_exp2f.exit1315 ] + %7549 = phi float [ %5060, %4945 ], [ %6460, %__nv_exp2f.exit1315 ] + %7550 = phi float [ %5061, %4945 ], [ %6461, %__nv_exp2f.exit1315 ] + %7551 = phi float [ %5062, %4945 ], [ %6462, %__nv_exp2f.exit1315 ] + %7552 = phi float [ %5063, %4945 ], [ %6463, %__nv_exp2f.exit1315 ] + %7553 = phi float [ %5064, %4945 ], [ %6464, %__nv_exp2f.exit1315 ] + %7554 = phi float [ %5065, %4945 ], [ %6465, %__nv_exp2f.exit1315 ] + %7555 = phi float [ %5066, %4945 ], [ %6466, %__nv_exp2f.exit1315 ] + %7556 = phi float [ %5067, %4945 ], [ %6467, %__nv_exp2f.exit1315 ] + %7557 = phi float [ %5068, %4945 ], [ %6468, %__nv_exp2f.exit1315 ] + %7558 = phi float [ %5069, %4945 ], [ %6469, %__nv_exp2f.exit1315 ] + %7559 = phi float [ %5070, %4945 ], [ %6470, %__nv_exp2f.exit1315 ] + %7560 = phi float [ %5071, %4945 ], [ %6471, %__nv_exp2f.exit1315 ] + %7561 = phi float [ %5072, %4945 ], [ %6472, %__nv_exp2f.exit1315 ] + %7562 = phi float [ %5073, %4945 ], [ %6473, %__nv_exp2f.exit1315 ] + %7563 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %7499, float %7500, float %7501, float %7502, float %7503, float %7504, float %7505, float %7506, float %7507, float %7508, float %7509, float %7510, float %7511, float %7512, float %7513, float %7514, float %7515, float %7516, float %7517, float %7518, float %7519, float %7520, float %7521, float %7522, float %7523, float %7524, float %7525, float %7526, float %7527, float %7528, float %7529, float %7530, float %7531, float 
%7532, float %7533, float %7534, float %7535, float %7536, float %7537, float %7538, float %7539, float %7540, float %7541, float %7542, float %7543, float %7544, float %7545, float %7546, float %7547, float %7548, float %7549, float %7550, float %7551, float %7552, float %7553, float %7554, float %7555, float %7556, float %7557, float %7558, float %7559, float %7560, float %7561, float %7562, float %7435, float %7436, float %7437, float %7438, float %7439, float %7440, float %7441, float %7442, float %7443, float %7444, float %7445, float %7446, float %7447, float %7448, float %7449, float %7450, float %7451, float %7452, float %7453, float %7454, float %7455, float %7456, float %7457, float %7458, float %7459, float %7460, float %7461, float %7462, float %7463, float %7464, float %7465, float %7466, float %7467, float %7468, float %7469, float %7470, float %7471, float %7472, float %7473, float %7474, float %7475, float %7476, float %7477, float %7478, float %7479, float %7480, float %7481, float %7482, float %7483, float %7484, float %7485, float %7486, float %7487, float %7488, float %7489, float %7490, float %7491, float %7492, float %7493, float %7494, float %7495, float %7496, float %7497, float %7498) #3, !dbg !213 + %7564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 0, !dbg !213 + %7565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 1, !dbg !213 + %7566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %7563, 2, !dbg !213 + %7567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 3, !dbg !213 + %7568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 4, !dbg !213 + %7569 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 5, !dbg !213 + %7570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 6, !dbg !213 + %7571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 7, !dbg !213 + %7572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 8, !dbg !213 + %7573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 9, !dbg !213 + %7574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 10, !dbg !213 + %7575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 11, !dbg !213 + %7576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 12, !dbg !213 + %7577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 13, !dbg !213 + %7578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 14, !dbg !213 + %7579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 15, !dbg !213 + %7580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 16, !dbg !213 + %7581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 17, !dbg !213 + %7582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 18, !dbg !213 + %7583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 19, !dbg !213 + %7584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 20, !dbg !213 + %7585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 21, 
!dbg !213 + %7586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 22, !dbg !213 + %7587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 23, !dbg !213 + %7588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 24, !dbg !213 + %7589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 25, !dbg !213 + %7590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 26, !dbg !213 + %7591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 27, !dbg !213 + %7592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 28, !dbg !213 + %7593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 29, !dbg !213 + %7594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 30, !dbg !213 + %7595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 31, !dbg !213 + %7596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 32, !dbg !213 + %7597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 33, !dbg !213 + %7598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 34, !dbg !213 + %7599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 35, !dbg !213 + %7600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 36, !dbg !213 + %7601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 37, !dbg !213 + %7602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %7563, 38, !dbg !213 + %7603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 39, !dbg !213 + %7604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 40, !dbg !213 + %7605 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 41, !dbg !213 + %7606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 42, !dbg !213 + %7607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 43, !dbg !213 + %7608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 44, !dbg !213 + %7609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 45, !dbg !213 + %7610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 46, !dbg !213 + %7611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 47, !dbg !213 + %7612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 48, !dbg !213 + %7613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 49, !dbg !213 + %7614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 50, !dbg !213 + %7615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 51, !dbg !213 + %7616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 52, !dbg !213 + %7617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 53, !dbg !213 + %7618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 54, !dbg !213 + %7619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %7563, 55, !dbg !213 + %7620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 56, !dbg !213 + %7621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 57, !dbg !213 + %7622 
= extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 58, !dbg !213 + %7623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 59, !dbg !213 + %7624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 60, !dbg !213 + %7625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 61, !dbg !213 + %7626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 62, !dbg !213 + %7627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 63, !dbg !213 + %7628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 64, !dbg !213 + %7629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 65, !dbg !213 + %7630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 66, !dbg !213 + %7631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 67, !dbg !213 + %7632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 68, !dbg !213 + %7633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 69, !dbg !213 + %7634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 70, !dbg !213 + %7635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 71, !dbg !213 + %7636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 72, !dbg !213 + %7637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 73, !dbg !213 + %7638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %7563, 74, !dbg !213 + %7639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 75, !dbg !213 + %7640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 76, !dbg !213 + %7641 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 77, !dbg !213 + %7642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 78, !dbg !213 + %7643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 79, !dbg !213 + %7644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 80, !dbg !213 + %7645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 81, !dbg !213 + %7646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 82, !dbg !213 + %7647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 83, !dbg !213 + %7648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 84, !dbg !213 + %7649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 85, !dbg !213 + %7650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 86, !dbg !213 + %7651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 87, !dbg !213 + %7652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 88, !dbg !213 + %7653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 89, !dbg !213 + %7654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 90, !dbg !213 + %7655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %7563, 91, !dbg !213 + %7656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 92, !dbg !213 + %7657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 93, !dbg !213 + %7658 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 94, !dbg !213 + %7659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 95, !dbg !213 + %7660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 96, !dbg !213 + %7661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 97, !dbg !213 + %7662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 98, !dbg !213 + %7663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 99, !dbg !213 + %7664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 100, !dbg !213 + %7665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 101, !dbg !213 + %7666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 102, !dbg !213 + %7667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 103, !dbg !213 + %7668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 104, !dbg !213 + %7669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 105, !dbg !213 + %7670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 106, !dbg !213 + %7671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 107, !dbg !213 + %7672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 108, !dbg !213 + %7673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 109, !dbg !213 + %7674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %7563, 110, !dbg !213 + %7675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 111, !dbg !213 + %7676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 112, !dbg !213 + %7677 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 113, !dbg !213 + %7678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 114, !dbg !213 + %7679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 115, !dbg !213 + %7680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 116, !dbg !213 + %7681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 117, !dbg !213 + %7682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 118, !dbg !213 + %7683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 119, !dbg !213 + %7684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 120, !dbg !213 + %7685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 121, !dbg !213 + %7686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 122, !dbg !213 + %7687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 123, !dbg !213 + %7688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 124, !dbg !213 + %7689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 125, !dbg !213 + %7690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7563, 126, !dbg !213 + %7691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %7563, 127, !dbg !213 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !213 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !213 + %7692 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4886, !dbg !286 + %7693 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4887, !dbg !286 + %7694 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4888, !dbg !286 + %7695 = getelementptr bfloat, ptr addrspace(1) %5085, i64 %4889, !dbg !286 + %7696 = getelementptr bfloat, ptr addrspace(1) %7692, i64 %4498, !dbg !287 + %7697 = getelementptr bfloat, ptr addrspace(1) %7693, i64 %4498, !dbg !287 + %7698 = getelementptr bfloat, ptr addrspace(1) %7694, i64 %4498, !dbg !287 + %7699 = getelementptr bfloat, ptr addrspace(1) %7695, i64 %4498, !dbg !287 + %7700 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4890, !dbg !288 + %7701 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4891, !dbg !288 + %7702 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4892, !dbg !288 + %7703 = getelementptr bfloat, ptr addrspace(1) %5086, i64 %4893, !dbg !288 + %7704 = getelementptr bfloat, ptr addrspace(1) %7700, i64 %4498, !dbg !289 + %7705 = getelementptr bfloat, ptr addrspace(1) %7701, i64 %4498, !dbg !289 + %7706 = getelementptr bfloat, ptr addrspace(1) %7702, i64 %4498, !dbg !289 + %7707 = getelementptr bfloat, ptr addrspace(1) %7703, i64 %4498, !dbg !289 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4794, ptr addrspace(1) %7696, i32 %4895) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4797, ptr addrspace(1) %7697, i32 %4895) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4799, ptr addrspace(1) %7698, i32 %4895) #3, !dbg 
!290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4801, ptr addrspace(1) %7699, i32 %4895) #3, !dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %7708 = getelementptr float, ptr addrspace(1) %5087, i64 %4896, !dbg !291 + %7709 = getelementptr float, ptr addrspace(1) %5087, i64 %4897, !dbg !291 + %7710 = getelementptr float, ptr addrspace(1) %5087, i64 %4898, !dbg !291 + %7711 = getelementptr float, ptr addrspace(1) %5087, i64 %4899, !dbg !291 + %7712 = getelementptr float, ptr addrspace(1) %5087, i64 %4900, !dbg !291 + %7713 = getelementptr float, ptr addrspace(1) %5087, i64 %4901, !dbg !291 + %7714 = getelementptr float, ptr addrspace(1) %5087, i64 %4902, !dbg !291 + %7715 = getelementptr float, ptr addrspace(1) %5087, i64 %4903, !dbg !291 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4813, ptr addrspace(1) %7708, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4816, ptr addrspace(1) %7709, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4818, ptr addrspace(1) %7710, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4820, ptr addrspace(1) %7711, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4822, ptr addrspace(1) %7712, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4824, ptr 
addrspace(1) %7713, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4826, ptr addrspace(1) %7714, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4828, ptr addrspace(1) %7715, i32 %4904, i1 %4811) #3, !dbg !292 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !292 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4829, ptr addrspace(1) %7704, i32 %4895) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4830, ptr addrspace(1) %7705, i32 %4895) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4831, ptr addrspace(1) %7706, i32 %4895) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4832, ptr addrspace(1) %7707, i32 %4895) #3, !dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %7716 = getelementptr float, ptr addrspace(1) %5088, i64 %4896, !dbg !293 + %7717 = getelementptr float, ptr addrspace(1) %5088, i64 %4897, !dbg !293 + %7718 = getelementptr float, ptr addrspace(1) %5088, i64 %4898, !dbg !293 + %7719 = getelementptr float, ptr addrspace(1) %5088, i64 %4899, !dbg !293 + %7720 = getelementptr float, ptr addrspace(1) %5088, i64 %4900, !dbg !293 + %7721 = getelementptr float, ptr addrspace(1) %5088, i64 %4901, !dbg !293 + %7722 = getelementptr float, ptr addrspace(1) %5088, i64 %4902, !dbg !293 + %7723 = getelementptr float, ptr addrspace(1) %5088, i64 %4903, !dbg !293 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 
], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4833, ptr addrspace(1) %7716, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4834, ptr addrspace(1) %7717, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4835, ptr addrspace(1) %7718, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4836, ptr addrspace(1) %7719, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4837, ptr addrspace(1) %7720, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4838, ptr addrspace(1) %7721, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4839, ptr addrspace(1) %7722, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4840, ptr addrspace(1) %7723, i32 %4904, i1 %4811) #3, !dbg !294 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !294 + %7724 = getelementptr i8, ptr addrspace(1) %7696, i64 524288, !dbg !295 + %7725 = getelementptr i8, ptr addrspace(1) %7697, i64 524288, !dbg !295 + %7726 = getelementptr i8, ptr addrspace(1) %7698, i64 524288, !dbg !295 + %7727 = getelementptr i8, ptr addrspace(1) %7699, i64 524288, !dbg !295 + %7728 = getelementptr i8, ptr addrspace(1) %7704, i64 16384, !dbg !296 + %7729 = getelementptr i8, ptr addrspace(1) %7705, i64 16384, !dbg 
!296 + %7730 = getelementptr i8, ptr addrspace(1) %7706, i64 16384, !dbg !296 + %7731 = getelementptr i8, ptr addrspace(1) %7707, i64 16384, !dbg !296 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4850, ptr addrspace(1) %7724, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4852, ptr addrspace(1) %7725, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4853, ptr addrspace(1) %7726, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4854, ptr addrspace(1) %7727, i32 %4914) #3, !dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %7732 = getelementptr float, ptr addrspace(1) %5087, i64 %4915, !dbg !291 + %7733 = getelementptr float, ptr addrspace(1) %5087, i64 %4916, !dbg !291 + %7734 = getelementptr float, ptr addrspace(1) %5087, i64 %4917, !dbg !291 + %7735 = getelementptr float, ptr addrspace(1) %5087, i64 %4918, !dbg !291 + %7736 = getelementptr float, ptr addrspace(1) %5087, i64 %4919, !dbg !291 + %7737 = getelementptr float, ptr addrspace(1) %5087, i64 %4920, !dbg !291 + %7738 = getelementptr float, ptr addrspace(1) %5087, i64 %4921, !dbg !291 + %7739 = getelementptr float, ptr addrspace(1) %5087, i64 %4922, !dbg !291 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4863, ptr addrspace(1) %7732, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4865, ptr addrspace(1) %7733, i32 %4923, i1 
%4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4866, ptr addrspace(1) %7734, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4867, ptr addrspace(1) %7735, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4868, ptr addrspace(1) %7736, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4869, ptr addrspace(1) %7737, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4870, ptr addrspace(1) %7738, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4871, ptr addrspace(1) %7739, i32 %4923, i1 %4811) #3, !dbg !292 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !292 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4872, ptr addrspace(1) %7728, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4873, ptr addrspace(1) %7729, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4874, ptr addrspace(1) %7730, i32 %4914) #3, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4875, ptr addrspace(1) %7731, i32 %4914) #3, 
!dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %7740 = getelementptr float, ptr addrspace(1) %5088, i64 %4915, !dbg !293 + %7741 = getelementptr float, ptr addrspace(1) %5088, i64 %4916, !dbg !293 + %7742 = getelementptr float, ptr addrspace(1) %5088, i64 %4917, !dbg !293 + %7743 = getelementptr float, ptr addrspace(1) %5088, i64 %4918, !dbg !293 + %7744 = getelementptr float, ptr addrspace(1) %5088, i64 %4919, !dbg !293 + %7745 = getelementptr float, ptr addrspace(1) %5088, i64 %4920, !dbg !293 + %7746 = getelementptr float, ptr addrspace(1) %5088, i64 %4921, !dbg !293 + %7747 = getelementptr float, ptr addrspace(1) %5088, i64 %4922, !dbg !293 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %4876, ptr addrspace(1) %7740, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4877, ptr addrspace(1) %7741, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4878, ptr addrspace(1) %7742, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4879, ptr addrspace(1) %7743, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4880, ptr addrspace(1) %7744, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4881, ptr addrspace(1) %7745, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull 
%4882, ptr addrspace(1) %7746, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %4883, ptr addrspace(1) %7747, i32 %4923, i1 %4811) #3, !dbg !294 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !294 + br i1 %4894, label %.lr.ph1691, label %._crit_edge1692, !dbg !297 + +.lr.ph1691: ; preds = %._crit_edge, %__nv_exp2f.exit1219 + %7748 = phi i32 [ %7757, %__nv_exp2f.exit1219 ], [ -1, %._crit_edge ] + %7749 = phi i32 [ %9619, %__nv_exp2f.exit1219 ], [ 1, %._crit_edge ] + %7750 = phi i32 [ %7760, %__nv_exp2f.exit1219 ], [ -1, %._crit_edge ] + %7751 = phi i32 [ %9622, %__nv_exp2f.exit1219 ], [ 1, %._crit_edge ] + %.pn6361689 = phi i32 [ %9616, %__nv_exp2f.exit1219 ], [ %4913, %._crit_edge ] + %.pn6401688 = phi i32 [ %9615, %__nv_exp2f.exit1219 ], [ %4912, %._crit_edge ] + %.pn6441687 = phi i32 [ %9614, %__nv_exp2f.exit1219 ], [ %4911, %._crit_edge ] + %.pn6481686 = phi i32 [ %9613, %__nv_exp2f.exit1219 ], [ %4910, %._crit_edge ] + %.pn6521685 = phi i32 [ %9612, %__nv_exp2f.exit1219 ], [ %4909, %._crit_edge ] + %.pn6561684 = phi i32 [ %9611, %__nv_exp2f.exit1219 ], [ %4908, %._crit_edge ] + %.pn6601683 = phi i32 [ %9610, %__nv_exp2f.exit1219 ], [ %4907, %._crit_edge ] + %.pn6641682 = phi i32 [ %9609, %__nv_exp2f.exit1219 ], [ %4906, %._crit_edge ] + %.pn5841681 = phi ptr addrspace(1) [ %9608, %__nv_exp2f.exit1219 ], [ %7731, %._crit_edge ] + %.pn6001680 = phi ptr addrspace(1) [ %9607, %__nv_exp2f.exit1219 ], [ %7730, %._crit_edge ] + %.pn6161679 = phi ptr addrspace(1) [ %9606, %__nv_exp2f.exit1219 ], [ %7729, %._crit_edge ] + %.pn6321678 = phi ptr addrspace(1) [ %9605, %__nv_exp2f.exit1219 ], [ %7728, %._crit_edge ] + %.pn5201677 = phi ptr addrspace(1) [ %9602, %__nv_exp2f.exit1219 ], [ %7727, %._crit_edge ] + %.pn5361676 = phi ptr addrspace(1) [ %9601, %__nv_exp2f.exit1219 ], [ %7726, %._crit_edge ] + %.pn5521675 = phi ptr addrspace(1) [ %9600, 
%__nv_exp2f.exit1219 ], [ %7725, %._crit_edge ] + %.pn5681674 = phi ptr addrspace(1) [ %9599, %__nv_exp2f.exit1219 ], [ %7724, %._crit_edge ] + %.pn3781673 = phi float [ %8816, %__nv_exp2f.exit1219 ], [ %7627, %._crit_edge ] + %.pn3801672 = phi float [ %8815, %__nv_exp2f.exit1219 ], [ %7626, %._crit_edge ] + %.pn3821671 = phi float [ %8814, %__nv_exp2f.exit1219 ], [ %7625, %._crit_edge ] + %.pn3841670 = phi float [ %8813, %__nv_exp2f.exit1219 ], [ %7624, %._crit_edge ] + %.pn3861669 = phi float [ %8812, %__nv_exp2f.exit1219 ], [ %7623, %._crit_edge ] + %.pn3881668 = phi float [ %8811, %__nv_exp2f.exit1219 ], [ %7622, %._crit_edge ] + %.pn3901667 = phi float [ %8810, %__nv_exp2f.exit1219 ], [ %7621, %._crit_edge ] + %.pn3921666 = phi float [ %8809, %__nv_exp2f.exit1219 ], [ %7620, %._crit_edge ] + %.pn3941665 = phi float [ %8808, %__nv_exp2f.exit1219 ], [ %7619, %._crit_edge ] + %.pn3961664 = phi float [ %8807, %__nv_exp2f.exit1219 ], [ %7618, %._crit_edge ] + %.pn3981663 = phi float [ %8806, %__nv_exp2f.exit1219 ], [ %7617, %._crit_edge ] + %.pn4001662 = phi float [ %8805, %__nv_exp2f.exit1219 ], [ %7616, %._crit_edge ] + %.pn4021661 = phi float [ %8804, %__nv_exp2f.exit1219 ], [ %7615, %._crit_edge ] + %.pn4041660 = phi float [ %8803, %__nv_exp2f.exit1219 ], [ %7614, %._crit_edge ] + %.pn4061659 = phi float [ %8802, %__nv_exp2f.exit1219 ], [ %7613, %._crit_edge ] + %.pn4081658 = phi float [ %8801, %__nv_exp2f.exit1219 ], [ %7612, %._crit_edge ] + %.pn4101657 = phi float [ %8800, %__nv_exp2f.exit1219 ], [ %7611, %._crit_edge ] + %.pn4121656 = phi float [ %8799, %__nv_exp2f.exit1219 ], [ %7610, %._crit_edge ] + %.pn4141655 = phi float [ %8798, %__nv_exp2f.exit1219 ], [ %7609, %._crit_edge ] + %.pn4161654 = phi float [ %8797, %__nv_exp2f.exit1219 ], [ %7608, %._crit_edge ] + %.pn4181653 = phi float [ %8796, %__nv_exp2f.exit1219 ], [ %7607, %._crit_edge ] + %.pn4201652 = phi float [ %8795, %__nv_exp2f.exit1219 ], [ %7606, %._crit_edge ] + %.pn4221651 = phi float [ 
%8794, %__nv_exp2f.exit1219 ], [ %7605, %._crit_edge ] + %.pn4241650 = phi float [ %8793, %__nv_exp2f.exit1219 ], [ %7604, %._crit_edge ] + %.pn4261649 = phi float [ %8792, %__nv_exp2f.exit1219 ], [ %7603, %._crit_edge ] + %.pn4281648 = phi float [ %8791, %__nv_exp2f.exit1219 ], [ %7602, %._crit_edge ] + %.pn4301647 = phi float [ %8790, %__nv_exp2f.exit1219 ], [ %7601, %._crit_edge ] + %.pn4321646 = phi float [ %8789, %__nv_exp2f.exit1219 ], [ %7600, %._crit_edge ] + %.pn4341645 = phi float [ %8788, %__nv_exp2f.exit1219 ], [ %7599, %._crit_edge ] + %.pn4361644 = phi float [ %8787, %__nv_exp2f.exit1219 ], [ %7598, %._crit_edge ] + %.pn4381643 = phi float [ %8786, %__nv_exp2f.exit1219 ], [ %7597, %._crit_edge ] + %.pn4401642 = phi float [ %8785, %__nv_exp2f.exit1219 ], [ %7596, %._crit_edge ] + %.pn4421641 = phi float [ %8784, %__nv_exp2f.exit1219 ], [ %7595, %._crit_edge ] + %.pn4441640 = phi float [ %8783, %__nv_exp2f.exit1219 ], [ %7594, %._crit_edge ] + %.pn4461639 = phi float [ %8782, %__nv_exp2f.exit1219 ], [ %7593, %._crit_edge ] + %.pn4481638 = phi float [ %8781, %__nv_exp2f.exit1219 ], [ %7592, %._crit_edge ] + %.pn4501637 = phi float [ %8780, %__nv_exp2f.exit1219 ], [ %7591, %._crit_edge ] + %.pn4521636 = phi float [ %8779, %__nv_exp2f.exit1219 ], [ %7590, %._crit_edge ] + %.pn4541635 = phi float [ %8778, %__nv_exp2f.exit1219 ], [ %7589, %._crit_edge ] + %.pn4561634 = phi float [ %8777, %__nv_exp2f.exit1219 ], [ %7588, %._crit_edge ] + %.pn4581633 = phi float [ %8776, %__nv_exp2f.exit1219 ], [ %7587, %._crit_edge ] + %.pn4601632 = phi float [ %8775, %__nv_exp2f.exit1219 ], [ %7586, %._crit_edge ] + %.pn4621631 = phi float [ %8774, %__nv_exp2f.exit1219 ], [ %7585, %._crit_edge ] + %.pn4641630 = phi float [ %8773, %__nv_exp2f.exit1219 ], [ %7584, %._crit_edge ] + %.pn4661629 = phi float [ %8772, %__nv_exp2f.exit1219 ], [ %7583, %._crit_edge ] + %.pn4681628 = phi float [ %8771, %__nv_exp2f.exit1219 ], [ %7582, %._crit_edge ] + %.pn4701627 = phi float [ %8770, 
%__nv_exp2f.exit1219 ], [ %7581, %._crit_edge ] + %.pn4721626 = phi float [ %8769, %__nv_exp2f.exit1219 ], [ %7580, %._crit_edge ] + %.pn4741625 = phi float [ %8768, %__nv_exp2f.exit1219 ], [ %7579, %._crit_edge ] + %.pn4761624 = phi float [ %8767, %__nv_exp2f.exit1219 ], [ %7578, %._crit_edge ] + %.pn4781623 = phi float [ %8766, %__nv_exp2f.exit1219 ], [ %7577, %._crit_edge ] + %.pn4801622 = phi float [ %8765, %__nv_exp2f.exit1219 ], [ %7576, %._crit_edge ] + %.pn4821621 = phi float [ %8764, %__nv_exp2f.exit1219 ], [ %7575, %._crit_edge ] + %.pn4841620 = phi float [ %8763, %__nv_exp2f.exit1219 ], [ %7574, %._crit_edge ] + %.pn4861619 = phi float [ %8762, %__nv_exp2f.exit1219 ], [ %7573, %._crit_edge ] + %.pn4881618 = phi float [ %8761, %__nv_exp2f.exit1219 ], [ %7572, %._crit_edge ] + %.pn4901617 = phi float [ %8760, %__nv_exp2f.exit1219 ], [ %7571, %._crit_edge ] + %.pn4921616 = phi float [ %8759, %__nv_exp2f.exit1219 ], [ %7570, %._crit_edge ] + %.pn4941615 = phi float [ %8758, %__nv_exp2f.exit1219 ], [ %7569, %._crit_edge ] + %.pn4961614 = phi float [ %8757, %__nv_exp2f.exit1219 ], [ %7568, %._crit_edge ] + %.pn4981613 = phi float [ %8756, %__nv_exp2f.exit1219 ], [ %7567, %._crit_edge ] + %.pn5001612 = phi float [ %8755, %__nv_exp2f.exit1219 ], [ %7566, %._crit_edge ] + %.pn5021611 = phi float [ %8754, %__nv_exp2f.exit1219 ], [ %7565, %._crit_edge ] + %.pn5041610 = phi float [ %8753, %__nv_exp2f.exit1219 ], [ %7564, %._crit_edge ] + %.pn2501609 = phi float [ %9576, %__nv_exp2f.exit1219 ], [ %7691, %._crit_edge ] + %.pn2521608 = phi float [ %9575, %__nv_exp2f.exit1219 ], [ %7690, %._crit_edge ] + %.pn2541607 = phi float [ %9574, %__nv_exp2f.exit1219 ], [ %7689, %._crit_edge ] + %.pn2561606 = phi float [ %9573, %__nv_exp2f.exit1219 ], [ %7688, %._crit_edge ] + %.pn2581605 = phi float [ %9572, %__nv_exp2f.exit1219 ], [ %7687, %._crit_edge ] + %.pn2601604 = phi float [ %9571, %__nv_exp2f.exit1219 ], [ %7686, %._crit_edge ] + %.pn2621603 = phi float [ %9570, 
%__nv_exp2f.exit1219 ], [ %7685, %._crit_edge ] + %.pn2641602 = phi float [ %9569, %__nv_exp2f.exit1219 ], [ %7684, %._crit_edge ] + %.pn2661601 = phi float [ %9568, %__nv_exp2f.exit1219 ], [ %7683, %._crit_edge ] + %.pn2681600 = phi float [ %9567, %__nv_exp2f.exit1219 ], [ %7682, %._crit_edge ] + %.pn2701599 = phi float [ %9566, %__nv_exp2f.exit1219 ], [ %7681, %._crit_edge ] + %.pn2721598 = phi float [ %9565, %__nv_exp2f.exit1219 ], [ %7680, %._crit_edge ] + %.pn2741597 = phi float [ %9564, %__nv_exp2f.exit1219 ], [ %7679, %._crit_edge ] + %.pn2761596 = phi float [ %9563, %__nv_exp2f.exit1219 ], [ %7678, %._crit_edge ] + %.pn2781595 = phi float [ %9562, %__nv_exp2f.exit1219 ], [ %7677, %._crit_edge ] + %.pn2801594 = phi float [ %9561, %__nv_exp2f.exit1219 ], [ %7676, %._crit_edge ] + %.pn2821593 = phi float [ %9560, %__nv_exp2f.exit1219 ], [ %7675, %._crit_edge ] + %.pn2841592 = phi float [ %9559, %__nv_exp2f.exit1219 ], [ %7674, %._crit_edge ] + %.pn2861591 = phi float [ %9558, %__nv_exp2f.exit1219 ], [ %7673, %._crit_edge ] + %.pn2881590 = phi float [ %9557, %__nv_exp2f.exit1219 ], [ %7672, %._crit_edge ] + %.pn2901589 = phi float [ %9556, %__nv_exp2f.exit1219 ], [ %7671, %._crit_edge ] + %.pn2921588 = phi float [ %9555, %__nv_exp2f.exit1219 ], [ %7670, %._crit_edge ] + %.pn2941587 = phi float [ %9554, %__nv_exp2f.exit1219 ], [ %7669, %._crit_edge ] + %.pn2961586 = phi float [ %9553, %__nv_exp2f.exit1219 ], [ %7668, %._crit_edge ] + %.pn2981585 = phi float [ %9552, %__nv_exp2f.exit1219 ], [ %7667, %._crit_edge ] + %.pn3001584 = phi float [ %9551, %__nv_exp2f.exit1219 ], [ %7666, %._crit_edge ] + %.pn3021583 = phi float [ %9550, %__nv_exp2f.exit1219 ], [ %7665, %._crit_edge ] + %.pn3041582 = phi float [ %9549, %__nv_exp2f.exit1219 ], [ %7664, %._crit_edge ] + %.pn3061581 = phi float [ %9548, %__nv_exp2f.exit1219 ], [ %7663, %._crit_edge ] + %.pn3081580 = phi float [ %9547, %__nv_exp2f.exit1219 ], [ %7662, %._crit_edge ] + %.pn3101579 = phi float [ %9546, 
%__nv_exp2f.exit1219 ], [ %7661, %._crit_edge ] + %.pn3121578 = phi float [ %9545, %__nv_exp2f.exit1219 ], [ %7660, %._crit_edge ] + %.pn3141577 = phi float [ %9544, %__nv_exp2f.exit1219 ], [ %7659, %._crit_edge ] + %.pn3161576 = phi float [ %9543, %__nv_exp2f.exit1219 ], [ %7658, %._crit_edge ] + %.pn3181575 = phi float [ %9542, %__nv_exp2f.exit1219 ], [ %7657, %._crit_edge ] + %.pn3201574 = phi float [ %9541, %__nv_exp2f.exit1219 ], [ %7656, %._crit_edge ] + %.pn3221573 = phi float [ %9540, %__nv_exp2f.exit1219 ], [ %7655, %._crit_edge ] + %.pn3241572 = phi float [ %9539, %__nv_exp2f.exit1219 ], [ %7654, %._crit_edge ] + %.pn3261571 = phi float [ %9538, %__nv_exp2f.exit1219 ], [ %7653, %._crit_edge ] + %.pn3281570 = phi float [ %9537, %__nv_exp2f.exit1219 ], [ %7652, %._crit_edge ] + %.pn3301569 = phi float [ %9536, %__nv_exp2f.exit1219 ], [ %7651, %._crit_edge ] + %.pn3321568 = phi float [ %9535, %__nv_exp2f.exit1219 ], [ %7650, %._crit_edge ] + %.pn3341567 = phi float [ %9534, %__nv_exp2f.exit1219 ], [ %7649, %._crit_edge ] + %.pn3361566 = phi float [ %9533, %__nv_exp2f.exit1219 ], [ %7648, %._crit_edge ] + %.pn3381565 = phi float [ %9532, %__nv_exp2f.exit1219 ], [ %7647, %._crit_edge ] + %.pn3401564 = phi float [ %9531, %__nv_exp2f.exit1219 ], [ %7646, %._crit_edge ] + %.pn3421563 = phi float [ %9530, %__nv_exp2f.exit1219 ], [ %7645, %._crit_edge ] + %.pn3441562 = phi float [ %9529, %__nv_exp2f.exit1219 ], [ %7644, %._crit_edge ] + %.pn3461561 = phi float [ %9528, %__nv_exp2f.exit1219 ], [ %7643, %._crit_edge ] + %.pn3481560 = phi float [ %9527, %__nv_exp2f.exit1219 ], [ %7642, %._crit_edge ] + %.pn3501559 = phi float [ %9526, %__nv_exp2f.exit1219 ], [ %7641, %._crit_edge ] + %.pn3521558 = phi float [ %9525, %__nv_exp2f.exit1219 ], [ %7640, %._crit_edge ] + %.pn3541557 = phi float [ %9524, %__nv_exp2f.exit1219 ], [ %7639, %._crit_edge ] + %.pn3561556 = phi float [ %9523, %__nv_exp2f.exit1219 ], [ %7638, %._crit_edge ] + %.pn3581555 = phi float [ %9522, 
%__nv_exp2f.exit1219 ], [ %7637, %._crit_edge ] + %.pn3601554 = phi float [ %9521, %__nv_exp2f.exit1219 ], [ %7636, %._crit_edge ] + %.pn3621553 = phi float [ %9520, %__nv_exp2f.exit1219 ], [ %7635, %._crit_edge ] + %.pn3641552 = phi float [ %9519, %__nv_exp2f.exit1219 ], [ %7634, %._crit_edge ] + %.pn3661551 = phi float [ %9518, %__nv_exp2f.exit1219 ], [ %7633, %._crit_edge ] + %.pn3681550 = phi float [ %9517, %__nv_exp2f.exit1219 ], [ %7632, %._crit_edge ] + %.pn3701549 = phi float [ %9516, %__nv_exp2f.exit1219 ], [ %7631, %._crit_edge ] + %.pn3721548 = phi float [ %9515, %__nv_exp2f.exit1219 ], [ %7630, %._crit_edge ] + %.pn3741547 = phi float [ %9514, %__nv_exp2f.exit1219 ], [ %7629, %._crit_edge ] + %.pn3761546 = phi float [ %9513, %__nv_exp2f.exit1219 ], [ %7628, %._crit_edge ] + %7752 = phi i32 [ %9577, %__nv_exp2f.exit1219 ], [ 0, %._crit_edge ] + %7753 = icmp slt i32 %7752, %4924, !dbg !297 + %7754 = icmp slt i32 %7752, %4925, !dbg !297 + %7755 = add i32 %7748, 1, !dbg !297 + %7756 = icmp sgt i32 %7755, 1, !dbg !297 + %7757 = select i1 %7756, i32 0, i32 %7755, !dbg !297 + %7758 = add i32 %7750, 1, !dbg !297 + %7759 = icmp sgt i32 %7758, 2, !dbg !297 + %7760 = select i1 %7759, i32 0, i32 %7758, !dbg !297 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !290 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !290 + %7761 = shl i32 %7760, 13, !dbg !290 + %7762 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %7761, !dbg !290 + %7763 = shl i32 %7757, 6, !dbg !292 + %7764 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %7763, !dbg !292 + %7765 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4812, !dbg !292 + %7766 = load float, ptr addrspace(3) %7765, align 8, !dbg !292 + %7767 = getelementptr inbounds nuw i8, ptr addrspace(3) %7765, i32 4, !dbg !292 + %7768 = load float, ptr addrspace(3) %7767, align 4, !dbg !292 + %7769 = getelementptr inbounds nuw 
i8, ptr addrspace(3) %7764, i32 %4815, !dbg !292 + %7770 = load float, ptr addrspace(3) %7769, align 8, !dbg !292 + %7771 = getelementptr inbounds nuw i8, ptr addrspace(3) %7769, i32 4, !dbg !292 + %7772 = load float, ptr addrspace(3) %7771, align 4, !dbg !292 + %7773 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4817, !dbg !292 + %7774 = load float, ptr addrspace(3) %7773, align 8, !dbg !292 + %7775 = getelementptr inbounds nuw i8, ptr addrspace(3) %7773, i32 4, !dbg !292 + %7776 = load float, ptr addrspace(3) %7775, align 4, !dbg !292 + %7777 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4819, !dbg !292 + %7778 = load float, ptr addrspace(3) %7777, align 8, !dbg !292 + %7779 = getelementptr inbounds nuw i8, ptr addrspace(3) %7777, i32 4, !dbg !292 + %7780 = load float, ptr addrspace(3) %7779, align 4, !dbg !292 + %7781 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4821, !dbg !292 + %7782 = load float, ptr addrspace(3) %7781, align 8, !dbg !292 + %7783 = getelementptr inbounds nuw i8, ptr addrspace(3) %7781, i32 4, !dbg !292 + %7784 = load float, ptr addrspace(3) %7783, align 4, !dbg !292 + %7785 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4823, !dbg !292 + %7786 = load float, ptr addrspace(3) %7785, align 8, !dbg !292 + %7787 = getelementptr inbounds nuw i8, ptr addrspace(3) %7785, i32 4, !dbg !292 + %7788 = load float, ptr addrspace(3) %7787, align 4, !dbg !292 + %7789 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4825, !dbg !292 + %7790 = load float, ptr addrspace(3) %7789, align 8, !dbg !292 + %7791 = getelementptr inbounds nuw i8, ptr addrspace(3) %7789, i32 4, !dbg !292 + %7792 = load float, ptr addrspace(3) %7791, align 4, !dbg !292 + %7793 = getelementptr inbounds nuw i8, ptr addrspace(3) %7764, i32 %4827, !dbg !292 + %7794 = load float, ptr addrspace(3) %7793, align 8, !dbg !292 + %7795 = getelementptr inbounds nuw i8, ptr addrspace(3) %7793, i32 4, !dbg !292 + %7796 = load 
float, ptr addrspace(3) %7795, align 4, !dbg !292 + %7797 = fcmp oeq float %7766, 0xFFF0000000000000, !dbg !298 + %7798 = fcmp oeq float %7768, 0xFFF0000000000000, !dbg !298 + %7799 = fcmp oeq float %7770, 0xFFF0000000000000, !dbg !298 + %7800 = fcmp oeq float %7772, 0xFFF0000000000000, !dbg !298 + %7801 = fcmp oeq float %7774, 0xFFF0000000000000, !dbg !298 + %7802 = fcmp oeq float %7776, 0xFFF0000000000000, !dbg !298 + %7803 = fcmp oeq float %7778, 0xFFF0000000000000, !dbg !298 + %7804 = fcmp oeq float %7780, 0xFFF0000000000000, !dbg !298 + %7805 = fcmp oeq float %7782, 0xFFF0000000000000, !dbg !298 + %7806 = fcmp oeq float %7784, 0xFFF0000000000000, !dbg !298 + %7807 = fcmp oeq float %7786, 0xFFF0000000000000, !dbg !298 + %7808 = fcmp oeq float %7788, 0xFFF0000000000000, !dbg !298 + %7809 = fcmp oeq float %7790, 0xFFF0000000000000, !dbg !298 + %7810 = fcmp oeq float %7792, 0xFFF0000000000000, !dbg !298 + %7811 = fcmp oeq float %7794, 0xFFF0000000000000, !dbg !298 + %7812 = fcmp oeq float %7796, 0xFFF0000000000000, !dbg !298 + %7813 = select i1 %7797, float 0.000000e+00, float %7766, !dbg !299 + %7814 = select i1 %7798, float 0.000000e+00, float %7768, !dbg !299 + %7815 = select i1 %7799, float 0.000000e+00, float %7770, !dbg !299 + %7816 = select i1 %7800, float 0.000000e+00, float %7772, !dbg !299 + %7817 = select i1 %7801, float 0.000000e+00, float %7774, !dbg !299 + %7818 = select i1 %7802, float 0.000000e+00, float %7776, !dbg !299 + %7819 = select i1 %7803, float 0.000000e+00, float %7778, !dbg !299 + %7820 = select i1 %7804, float 0.000000e+00, float %7780, !dbg !299 + %7821 = select i1 %7805, float 0.000000e+00, float %7782, !dbg !299 + %7822 = select i1 %7806, float 0.000000e+00, float %7784, !dbg !299 + %7823 = select i1 %7807, float 0.000000e+00, float %7786, !dbg !299 + %7824 = select i1 %7808, float 0.000000e+00, float %7788, !dbg !299 + %7825 = select i1 %7809, float 0.000000e+00, float %7790, !dbg !299 + %7826 = select i1 %7810, float 0.000000e+00, 
float %7792, !dbg !299 + %7827 = select i1 %7811, float 0.000000e+00, float %7794, !dbg !299 + %7828 = select i1 %7812, float 0.000000e+00, float %7796, !dbg !299 + %7829 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %36, i32 0, i32 31), !dbg !300 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !300 + %7830 = shl i32 %7829, 11, !dbg !300 + %7831 = and i32 %7830, 8192, !dbg !300 + %7832 = add i32 %7831, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %7833 = lshr exact i32 %7832, 4, !dbg !300 + %7834 = and i32 %7833, 16383, !dbg !300 + %7835 = zext nneg i32 %7834 to i64, !dbg !300 + %7836 = or disjoint i64 %7835, 4611686293372403712, !dbg !300 + %7837 = ptrtoint ptr addrspace(3) %7762 to i32, !dbg !300 + %7838 = lshr exact i32 %7837, 4, !dbg !300 + %7839 = and i32 %7838, 16383, !dbg !300 + %7840 = zext nneg i32 %7839 to i64, !dbg !300 + %7841 = or disjoint i64 %7840, 4611686293338849280, !dbg !300 + %7842 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %7836, i64 %7841) #3, !dbg !300 + %7843 = or disjoint i32 %7831, 32, !dbg !300 + %7844 = add i32 %7843, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %7845 = lshr exact i32 %7844, 4, !dbg !300 + %7846 = and i32 %7845, 16383, !dbg !300 + %7847 = zext nneg i32 %7846 to i64, !dbg !300 + %7848 = or disjoint i64 %7847, 4611686293372403712, !dbg !300 + %7849 = add i32 %7837, 
32, !dbg !300 + %7850 = lshr exact i32 %7849, 4, !dbg !300 + %7851 = and i32 %7850, 16383, !dbg !300 + %7852 = zext nneg i32 %7851 to i64, !dbg !300 + %7853 = or disjoint i64 %7852, 4611686293338849280, !dbg !300 + %7854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 0, !dbg !300 + %7855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 1, !dbg !300 + %7856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 2, !dbg !300 + %7857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 3, !dbg !300 + %7858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 4, !dbg !300 + %7859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 5, !dbg !300 + %7860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %7842, 6, !dbg !300 + %7861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 7, !dbg !300 + %7862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 8, !dbg !300 + %7863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 9, !dbg !300 + %7864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 10, !dbg !300 + %7865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 11, !dbg !300 + %7866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 12, !dbg !300 + %7867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 13, !dbg !300 
+ %7868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 14, !dbg !300 + %7869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 15, !dbg !300 + %7870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 16, !dbg !300 + %7871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 17, !dbg !300 + %7872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 18, !dbg !300 + %7873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 19, !dbg !300 + %7874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 20, !dbg !300 + %7875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 21, !dbg !300 + %7876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 22, !dbg !300 + %7877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 23, !dbg !300 + %7878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 24, !dbg !300 + %7879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 25, !dbg !300 + %7880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 26, !dbg !300 + %7881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 27, !dbg !300 + %7882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %7842, 28, !dbg !300 + %7883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 29, !dbg !300 + %7884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 30, !dbg !300 + %7885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7842, 31, !dbg !300 + %7886 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7854, float %7855, float %7856, float %7857, float %7858, float %7859, float %7860, float %7861, float %7862, float %7863, float %7864, float %7865, float %7866, float %7867, float %7868, float %7869, float %7870, float %7871, float %7872, float %7873, float %7874, float %7875, float %7876, float %7877, float %7878, float %7879, float %7880, float %7881, float %7882, float %7883, float %7884, float %7885, i64 %7848, i64 %7853, i1 true) #3, !dbg 
!300 + %7887 = or disjoint i32 %7831, 64, !dbg !300 + %7888 = add i32 %7887, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %7889 = lshr exact i32 %7888, 4, !dbg !300 + %7890 = and i32 %7889, 16383, !dbg !300 + %7891 = zext nneg i32 %7890 to i64, !dbg !300 + %7892 = or disjoint i64 %7891, 4611686293372403712, !dbg !300 + %7893 = add i32 %7837, 64, !dbg !300 + %7894 = lshr exact i32 %7893, 4, !dbg !300 + %7895 = and i32 %7894, 16383, !dbg !300 + %7896 = zext nneg i32 %7895 to i64, !dbg !300 + %7897 = or disjoint i64 %7896, 4611686293338849280, !dbg !300 + %7898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 0, !dbg !300 + %7899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 1, !dbg !300 + %7900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 2, !dbg !300 + %7901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 3, !dbg !300 + %7902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 4, !dbg !300 + %7903 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 5, !dbg !300 + %7904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 6, !dbg !300 + %7905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 7, !dbg !300 + %7906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 8, !dbg !300 + %7907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 9, !dbg !300 + %7908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 10, !dbg !300 + %7909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 11, !dbg !300 + %7910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 12, !dbg !300 + %7911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 13, !dbg !300 + %7912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 14, !dbg !300 + %7913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 15, !dbg !300 + %7914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 16, !dbg !300 + %7915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 17, !dbg !300 + %7916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 18, !dbg !300 + %7917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %7886, 19, !dbg !300 + %7918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 20, !dbg !300 + %7919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 21, !dbg !300 + %7920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 22, !dbg !300 + %7921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 23, !dbg !300 + %7922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 24, !dbg !300 + %7923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 25, !dbg !300 + %7924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 26, !dbg !300 + %7925 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 27, !dbg !300 + %7926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 28, !dbg !300 + %7927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 29, !dbg !300 + %7928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 30, !dbg !300 + %7929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7886, 31, !dbg !300 + %7930 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7898, float %7899, float %7900, float 
%7901, float %7902, float %7903, float %7904, float %7905, float %7906, float %7907, float %7908, float %7909, float %7910, float %7911, float %7912, float %7913, float %7914, float %7915, float %7916, float %7917, float %7918, float %7919, float %7920, float %7921, float %7922, float %7923, float %7924, float %7925, float %7926, float %7927, float %7928, float %7929, i64 %7892, i64 %7897, i1 true) #3, !dbg !300 + %7931 = or disjoint i32 %7831, 96, !dbg !300 + %7932 = add i32 %7931, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %7933 = lshr exact i32 %7932, 4, !dbg !300 + %7934 = and i32 %7933, 16383, !dbg !300 + %7935 = zext nneg i32 %7934 to i64, !dbg !300 + %7936 = or disjoint i64 %7935, 4611686293372403712, !dbg !300 + %7937 = add i32 %7837, 96, !dbg !300 + %7938 = lshr exact i32 %7937, 4, !dbg !300 + %7939 = and i32 %7938, 16383, !dbg !300 + %7940 = zext nneg i32 %7939 to i64, !dbg !300 + %7941 = or disjoint i64 %7940, 4611686293338849280, !dbg !300 + %7942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 0, !dbg !300 + %7943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 1, !dbg !300 + %7944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 2, !dbg !300 + %7945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 3, !dbg !300 + %7946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 4, !dbg !300 + %7947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 5, !dbg !300 + %7948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 6, !dbg !300 + %7949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 7, !dbg !300 + %7950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 8, !dbg !300 + %7951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 9, !dbg !300 + %7952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 10, !dbg 
!300 + %7953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 11, !dbg !300 + %7954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 12, !dbg !300 + %7955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 13, !dbg !300 + %7956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 14, !dbg !300 + %7957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 15, !dbg !300 + %7958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 16, !dbg !300 + %7959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 17, !dbg !300 + %7960 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 18, !dbg !300 + %7961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 19, !dbg !300 + %7962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 20, !dbg !300 + %7963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 21, !dbg !300 + %7964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 22, !dbg !300 + %7965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 23, !dbg !300 + %7966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 24, !dbg !300 + %7967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %7930, 25, !dbg !300 + %7968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 26, !dbg !300 + %7969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 27, !dbg !300 + %7970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 28, !dbg !300 + %7971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 29, !dbg !300 + %7972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 30, !dbg !300 + %7973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7930, 31, !dbg !300 + %7974 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect 
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7942, float %7943, float %7944, float %7945, float %7946, float %7947, float %7948, float %7949, float %7950, float %7951, float %7952, float %7953, float %7954, float %7955, float %7956, float %7957, float %7958, float %7959, float %7960, float %7961, float %7962, float %7963, float %7964, float %7965, float %7966, float %7967, float %7968, float %7969, float %7970, float %7971, float %7972, float %7973, i64 %7936, i64 %7941, i1 true) #3, !dbg !300 + %7975 = or disjoint i32 %7831, 16384, !dbg !300 + %7976 = add i32 %7975, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %7977 = lshr exact i32 %7976, 4, !dbg !300 + %7978 = and i32 %7977, 16383, !dbg !300 + %7979 = zext nneg i32 %7978 to i64, !dbg !300 + %7980 = or disjoint i64 %7979, 4611686293372403712, !dbg !300 + %7981 = add i32 %7837, 8192, !dbg !300 + %7982 = lshr exact i32 %7981, 4, !dbg !300 + %7983 = and i32 %7982, 16383, !dbg !300 + %7984 = zext nneg i32 %7983 to i64, !dbg !300 + %7985 = or disjoint i64 %7984, 4611686293338849280, !dbg !300 + %7986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 0, !dbg !300 + %7987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%7974, 1, !dbg !300 + %7988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 2, !dbg !300 + %7989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 3, !dbg !300 + %7990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 4, !dbg !300 + %7991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 5, !dbg !300 + %7992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 6, !dbg !300 + %7993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 7, !dbg !300 + %7994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 8, !dbg !300 + %7995 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 9, !dbg !300 + %7996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 10, !dbg !300 + %7997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 11, !dbg !300 + %7998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 12, !dbg !300 + %7999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 13, !dbg !300 + %8000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 14, !dbg !300 + %8001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 15, !dbg !300 + %8002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %7974, 16, !dbg !300 + %8003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 17, !dbg !300 + %8004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 18, !dbg !300 + %8005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 19, !dbg !300 + %8006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 20, !dbg !300 + %8007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 21, !dbg !300 + %8008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 22, !dbg !300 + %8009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 23, !dbg !300 + %8010 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 24, !dbg !300 + %8011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 25, !dbg !300 + %8012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 26, !dbg !300 + %8013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 27, !dbg !300 + %8014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 28, !dbg !300 + %8015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 29, !dbg !300 + %8016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 30, !dbg !300 + %8017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7974, 31, !dbg !300 + %8018 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7986, float %7987, float %7988, float %7989, float %7990, float %7991, float %7992, float %7993, float %7994, float %7995, float %7996, float %7997, float %7998, float %7999, float %8000, float %8001, float %8002, float %8003, float %8004, float %8005, float %8006, float %8007, float %8008, float %8009, float %8010, float %8011, float %8012, float %8013, float %8014, float %8015, float %8016, float %8017, i64 %7980, i64 %7985, i1 true) #3, !dbg !300 + %8019 = or disjoint i32 %7831, 16416, !dbg !300 + %8020 = add i32 %8019, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %8021 = lshr exact i32 %8020, 4, !dbg !300 + %8022 = and i32 %8021, 16383, !dbg !300 + %8023 = zext nneg i32 %8022 to i64, !dbg !300 + %8024 = or disjoint i64 %8023, 4611686293372403712, !dbg !300 + %8025 = add i32 %7837, 8224, !dbg !300 + %8026 = lshr exact i32 %8025, 4, !dbg !300 + %8027 = and i32 %8026, 16383, !dbg !300 + %8028 = zext nneg i32 %8027 to i64, !dbg !300 + %8029 = or disjoint i64 %8028, 4611686293338849280, !dbg !300 + %8030 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 0, !dbg !300 + %8031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 1, !dbg !300 + %8032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 2, !dbg !300 + %8033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 3, !dbg !300 + %8034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 4, !dbg !300 + %8035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 5, !dbg !300 + %8036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 6, !dbg !300 + %8037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %8018, 7, !dbg !300 + %8038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 8, !dbg !300 + %8039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 9, !dbg !300 + %8040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 10, !dbg !300 + %8041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 11, !dbg !300 + %8042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 12, !dbg !300 + %8043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 13, !dbg !300 + %8044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 14, !dbg !300 + %8045 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 15, !dbg !300 + %8046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 16, !dbg !300 + %8047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 17, !dbg !300 + %8048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 18, !dbg !300 + %8049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 19, !dbg !300 + %8050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 20, !dbg !300 + %8051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 21, !dbg !300 + %8052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 22, !dbg !300 + %8053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 23, !dbg !300 + %8054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 24, !dbg !300 + %8055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 25, !dbg !300 + %8056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 26, !dbg !300 + %8057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 27, !dbg !300 + %8058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 28, !dbg !300 + %8059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %8018, 29, !dbg !300 + %8060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 30, !dbg !300 + %8061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8018, 31, !dbg !300 + %8062 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8030, float %8031, float %8032, float %8033, float %8034, float %8035, float %8036, float %8037, float %8038, float %8039, float %8040, float %8041, float %8042, float %8043, float %8044, float %8045, float %8046, float %8047, float %8048, float %8049, float %8050, float %8051, float %8052, float %8053, float %8054, float %8055, float %8056, float %8057, float %8058, float %8059, float %8060, float %8061, i64 %8024, i64 %8029, i1 true) #3, !dbg !300 + %8063 = or disjoint i32 %7831, 16448, !dbg !300 + %8064 = add i32 %8063, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %8065 = lshr exact i32 %8064, 4, !dbg !300 + %8066 = and i32 %8065, 16383, !dbg !300 + %8067 = zext nneg i32 
%8066 to i64, !dbg !300 + %8068 = or disjoint i64 %8067, 4611686293372403712, !dbg !300 + %8069 = add i32 %7837, 8256, !dbg !300 + %8070 = lshr exact i32 %8069, 4, !dbg !300 + %8071 = and i32 %8070, 16383, !dbg !300 + %8072 = zext nneg i32 %8071 to i64, !dbg !300 + %8073 = or disjoint i64 %8072, 4611686293338849280, !dbg !300 + %8074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 0, !dbg !300 + %8075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 1, !dbg !300 + %8076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 2, !dbg !300 + %8077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 3, !dbg !300 + %8078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 4, !dbg !300 + %8079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 5, !dbg !300 + %8080 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 6, !dbg !300 + %8081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 7, !dbg !300 + %8082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 8, !dbg !300 + %8083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 9, !dbg !300 + %8084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 10, !dbg !300 + %8085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 11, !dbg !300 + %8086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 12, !dbg !300 + %8087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 13, !dbg !300 + %8088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 14, !dbg !300 + %8089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 15, !dbg !300 + %8090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 16, !dbg !300 + %8091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 17, !dbg !300 + %8092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 18, !dbg !300 + %8093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 19, !dbg !300 + %8094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%8062, 20, !dbg !300 + %8095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 21, !dbg !300 + %8096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 22, !dbg !300 + %8097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 23, !dbg !300 + %8098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 24, !dbg !300 + %8099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 25, !dbg !300 + %8100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 26, !dbg !300 + %8101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 27, !dbg !300 + %8102 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 28, !dbg !300 + %8103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 29, !dbg !300 + %8104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 30, !dbg !300 + %8105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8062, 31, !dbg !300 + %8106 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8074, float %8075, float %8076, float %8077, float %8078, float %8079, float %8080, float %8081, float %8082, float %8083, float %8084, float %8085, float %8086, float %8087, float %8088, float %8089, float %8090, float %8091, float %8092, float %8093, float %8094, float %8095, float %8096, float %8097, float %8098, float %8099, 
float %8100, float %8101, float %8102, float %8103, float %8104, float %8105, i64 %8068, i64 %8073, i1 true) #3, !dbg !300 + %8107 = or disjoint i32 %7831, 16480, !dbg !300 + %8108 = add i32 %8107, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !300 + %8109 = lshr exact i32 %8108, 4, !dbg !300 + %8110 = and i32 %8109, 16383, !dbg !300 + %8111 = zext nneg i32 %8110 to i64, !dbg !300 + %8112 = or disjoint i64 %8111, 4611686293372403712, !dbg !300 + %8113 = add i32 %7837, 8288, !dbg !300 + %8114 = lshr exact i32 %8113, 4, !dbg !300 + %8115 = and i32 %8114, 16383, !dbg !300 + %8116 = zext nneg i32 %8115 to i64, !dbg !300 + %8117 = or disjoint i64 %8116, 4611686293338849280, !dbg !300 + %8118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 0, !dbg !300 + %8119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 1, !dbg !300 + %8120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 2, !dbg !300 + %8121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 3, !dbg !300 + %8122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %8106, 4, !dbg !300 + %8123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 5, !dbg !300 + %8124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 6, !dbg !300 + %8125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 7, !dbg !300 + %8126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 8, !dbg !300 + %8127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 9, !dbg !300 + %8128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 10, !dbg !300 + %8129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 11, !dbg !300 + %8130 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 12, !dbg !300 + %8131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 13, !dbg !300 + %8132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 14, !dbg !300 + %8133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 15, !dbg !300 + %8134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 16, !dbg !300 + %8135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 17, !dbg !300 + %8136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 18, !dbg !300 + %8137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 19, !dbg !300 + %8138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 20, !dbg !300 + %8139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 21, !dbg !300 + %8140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 22, !dbg !300 + %8141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 23, !dbg !300 + %8142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 24, !dbg !300 + %8143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 25, !dbg !300 + %8144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %8106, 26, !dbg !300 + %8145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 27, !dbg !300 + %8146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 28, !dbg !300 + %8147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 29, !dbg !300 + %8148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 30, !dbg !300 + %8149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8106, 31, !dbg !300 + %8150 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8118, float %8119, float %8120, float %8121, float %8122, float %8123, float %8124, float %8125, float %8126, float %8127, float %8128, float %8129, float %8130, float %8131, float %8132, float %8133, float %8134, float %8135, float %8136, float %8137, float %8138, float %8139, float %8140, float %8141, float %8142, float %8143, float %8144, float %8145, float %8146, float %8147, float %8148, float %8149, i64 %8112, i64 %8117, i1 true) #3, !dbg !300 + %8151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 0, !dbg !300 + %8152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 1, !dbg !300 + %8153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 2, !dbg !300 + %8154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 3, !dbg !300 + %8155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 4, !dbg 
!300 + %8156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 5, !dbg !300 + %8157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 6, !dbg !300 + %8158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 7, !dbg !300 + %8159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 8, !dbg !300 + %8160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 9, !dbg !300 + %8161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 10, !dbg !300 + %8162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 11, !dbg !300 + %8163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 12, !dbg !300 + %8164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 13, !dbg !300 + %8165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 14, !dbg !300 + %8166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 15, !dbg !300 + %8167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 16, !dbg !300 + %8168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 17, !dbg !300 + %8169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 18, !dbg !300 + %8170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %8150, 19, !dbg !300 + %8171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 20, !dbg !300 + %8172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 21, !dbg !300 + %8173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 22, !dbg !300 + %8174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 23, !dbg !300 + %8175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 24, !dbg !300 + %8176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 25, !dbg !300 + %8177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 26, !dbg !300 + %8178 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 27, !dbg !300 + %8179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 28, !dbg !300 + %8180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 29, !dbg !300 + %8181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 30, !dbg !300 + %8182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8150, 31, !dbg !300 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !300 + %8183 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %8151, float %8152, float %8153, float %8154, float %8155, float %8156, float %8157, float %8158, float %8159, float %8160, float %8161, float %8162, float %8163, float %8164, float %8165, float %8166, float %8167, float %8168, float %8169, float %8170, float %8171, float %8172, float %8173, float %8174, float %8175, float %8176, float %8177, float %8178, float %8179, float %8180, float %8181, float %8182, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %7762, i32 0, i32 0) #3, !dbg !300 + %8184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 0, !dbg !300 + %8185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 1, !dbg !300 + %8186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 2, !dbg !300 + %8187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 3, !dbg !300 + %8188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 4, !dbg !300 + %8189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 5, !dbg !300 + %8190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 6, !dbg !300 + %8191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 7, !dbg !300 + %8192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 8, !dbg !300 + %8193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, 
ptr addrspace(3), i32, i32 } %8183, 9, !dbg !300 + %8194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 10, !dbg !300 + %8195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 11, !dbg !300 + %8196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 12, !dbg !300 + %8197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 13, !dbg !300 + %8198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 14, !dbg !300 + %8199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 
15, !dbg !300 + %8200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 16, !dbg !300 + %8201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 17, !dbg !300 + %8202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 18, !dbg !300 + %8203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 19, !dbg !300 + %8204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 20, !dbg !300 + %8205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 21, !dbg !300 + %8206 = extractvalue 
{ float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 22, !dbg !300 + %8207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 23, !dbg !300 + %8208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 24, !dbg !300 + %8209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 25, !dbg !300 + %8210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 26, !dbg !300 + %8211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 27, !dbg !300 + %8212 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 28, !dbg !300 + %8213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 29, !dbg !300 + %8214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 30, !dbg !300 + %8215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %8183, 31, !dbg !300 + %8216 = fmul float %8184, 0x3FB6A09E60000000, !dbg !301 + %8217 = fmul float %8185, 0x3FB6A09E60000000, !dbg !301 + %8218 = fmul float %8186, 0x3FB6A09E60000000, !dbg !301 + %8219 = fmul float %8187, 0x3FB6A09E60000000, !dbg !301 + %8220 = fmul float %8188, 0x3FB6A09E60000000, !dbg !301 + %8221 = fmul float %8189, 0x3FB6A09E60000000, !dbg !301 + %8222 = fmul float %8190, 0x3FB6A09E60000000, !dbg !301 + %8223 = fmul float %8191, 0x3FB6A09E60000000, !dbg !301 + %8224 = fmul float %8192, 0x3FB6A09E60000000, !dbg !301 + %8225 = fmul float %8193, 0x3FB6A09E60000000, !dbg !301 + %8226 = fmul float %8194, 0x3FB6A09E60000000, !dbg !301 + %8227 = fmul float %8195, 0x3FB6A09E60000000, !dbg !301 + %8228 = fmul float %8196, 0x3FB6A09E60000000, !dbg 
!301 + %8229 = fmul float %8197, 0x3FB6A09E60000000, !dbg !301 + %8230 = fmul float %8198, 0x3FB6A09E60000000, !dbg !301 + %8231 = fmul float %8199, 0x3FB6A09E60000000, !dbg !301 + %8232 = fmul float %8200, 0x3FB6A09E60000000, !dbg !301 + %8233 = fmul float %8201, 0x3FB6A09E60000000, !dbg !301 + %8234 = fmul float %8202, 0x3FB6A09E60000000, !dbg !301 + %8235 = fmul float %8203, 0x3FB6A09E60000000, !dbg !301 + %8236 = fmul float %8204, 0x3FB6A09E60000000, !dbg !301 + %8237 = fmul float %8205, 0x3FB6A09E60000000, !dbg !301 + %8238 = fmul float %8206, 0x3FB6A09E60000000, !dbg !301 + %8239 = fmul float %8207, 0x3FB6A09E60000000, !dbg !301 + %8240 = fmul float %8208, 0x3FB6A09E60000000, !dbg !301 + %8241 = fmul float %8209, 0x3FB6A09E60000000, !dbg !301 + %8242 = fmul float %8210, 0x3FB6A09E60000000, !dbg !301 + %8243 = fmul float %8211, 0x3FB6A09E60000000, !dbg !301 + %8244 = fmul float %8212, 0x3FB6A09E60000000, !dbg !301 + %8245 = fmul float %8213, 0x3FB6A09E60000000, !dbg !301 + %8246 = fmul float %8214, 0x3FB6A09E60000000, !dbg !301 + %8247 = fmul float %8215, 0x3FB6A09E60000000, !dbg !301 + %8248 = fmul float %8216, 0x3FF7154760000000, !dbg !302 + %8249 = fmul float %8217, 0x3FF7154760000000, !dbg !302 + %8250 = fmul float %8218, 0x3FF7154760000000, !dbg !302 + %8251 = fmul float %8219, 0x3FF7154760000000, !dbg !302 + %8252 = fmul float %8220, 0x3FF7154760000000, !dbg !302 + %8253 = fmul float %8221, 0x3FF7154760000000, !dbg !302 + %8254 = fmul float %8222, 0x3FF7154760000000, !dbg !302 + %8255 = fmul float %8223, 0x3FF7154760000000, !dbg !302 + %8256 = fmul float %8224, 0x3FF7154760000000, !dbg !302 + %8257 = fmul float %8225, 0x3FF7154760000000, !dbg !302 + %8258 = fmul float %8226, 0x3FF7154760000000, !dbg !302 + %8259 = fmul float %8227, 0x3FF7154760000000, !dbg !302 + %8260 = fmul float %8228, 0x3FF7154760000000, !dbg !302 + %8261 = fmul float %8229, 0x3FF7154760000000, !dbg !302 + %8262 = fmul float %8230, 0x3FF7154760000000, !dbg !302 + %8263 = fmul float 
%8231, 0x3FF7154760000000, !dbg !302 + %8264 = fmul float %8232, 0x3FF7154760000000, !dbg !302 + %8265 = fmul float %8233, 0x3FF7154760000000, !dbg !302 + %8266 = fmul float %8234, 0x3FF7154760000000, !dbg !302 + %8267 = fmul float %8235, 0x3FF7154760000000, !dbg !302 + %8268 = fmul float %8236, 0x3FF7154760000000, !dbg !302 + %8269 = fmul float %8237, 0x3FF7154760000000, !dbg !302 + %8270 = fmul float %8238, 0x3FF7154760000000, !dbg !302 + %8271 = fmul float %8239, 0x3FF7154760000000, !dbg !302 + %8272 = fmul float %8240, 0x3FF7154760000000, !dbg !302 + %8273 = fmul float %8241, 0x3FF7154760000000, !dbg !302 + %8274 = fmul float %8242, 0x3FF7154760000000, !dbg !302 + %8275 = fmul float %8243, 0x3FF7154760000000, !dbg !302 + %8276 = fmul float %8244, 0x3FF7154760000000, !dbg !302 + %8277 = fmul float %8245, 0x3FF7154760000000, !dbg !302 + %8278 = fmul float %8246, 0x3FF7154760000000, !dbg !302 + %8279 = fmul float %8247, 0x3FF7154760000000, !dbg !302 + %8280 = fsub float %8248, %7813, !dbg !303 + %8281 = fsub float %8249, %7814, !dbg !303 + %8282 = fsub float %8250, %7813, !dbg !303 + %8283 = fsub float %8251, %7814, !dbg !303 + %8284 = fsub float %8252, %7815, !dbg !303 + %8285 = fsub float %8253, %7816, !dbg !303 + %8286 = fsub float %8254, %7815, !dbg !303 + %8287 = fsub float %8255, %7816, !dbg !303 + %8288 = fsub float %8256, %7817, !dbg !303 + %8289 = fsub float %8257, %7818, !dbg !303 + %8290 = fsub float %8258, %7817, !dbg !303 + %8291 = fsub float %8259, %7818, !dbg !303 + %8292 = fsub float %8260, %7819, !dbg !303 + %8293 = fsub float %8261, %7820, !dbg !303 + %8294 = fsub float %8262, %7819, !dbg !303 + %8295 = fsub float %8263, %7820, !dbg !303 + %8296 = fsub float %8264, %7821, !dbg !303 + %8297 = fsub float %8265, %7822, !dbg !303 + %8298 = fsub float %8266, %7821, !dbg !303 + %8299 = fsub float %8267, %7822, !dbg !303 + %8300 = fsub float %8268, %7823, !dbg !303 + %8301 = fsub float %8269, %7824, !dbg !303 + %8302 = fsub float %8270, %7823, !dbg !303 
+ %8303 = fsub float %8271, %7824, !dbg !303 + %8304 = fsub float %8272, %7825, !dbg !303 + %8305 = fsub float %8273, %7826, !dbg !303 + %8306 = fsub float %8274, %7825, !dbg !303 + %8307 = fsub float %8275, %7826, !dbg !303 + %8308 = fsub float %8276, %7827, !dbg !303 + %8309 = fsub float %8277, %7828, !dbg !303 + %8310 = fsub float %8278, %7827, !dbg !303 + %8311 = fsub float %8279, %7828, !dbg !303 + %8312 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i = icmp eq i32 %8312, 0, !dbg !304 + br i1 %.not.i, label %8315, label %8313, !dbg !304 + +8313: ; preds = %.lr.ph1691 + %8314 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8280) #3, !dbg !304 + br label %__nv_exp2f.exit, !dbg !304 + +8315: ; preds = %.lr.ph1691 + %8316 = tail call float @llvm.nvvm.ex2.approx.f(float %8280) #3, !dbg !304 + br label %__nv_exp2f.exit, !dbg !304 + +__nv_exp2f.exit: ; preds = %8313, %8315 + %.0.i = phi float [ %8314, %8313 ], [ %8316, %8315 ], !dbg !304 + %8317 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1127 = icmp eq i32 %8317, 0, !dbg !304 + br i1 %.not.i1127, label %8320, label %8318, !dbg !304 + +8318: ; preds = %__nv_exp2f.exit + %8319 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8281) #3, !dbg !304 + br label %__nv_exp2f.exit1129, !dbg !304 + +8320: ; preds = %__nv_exp2f.exit + %8321 = tail call float @llvm.nvvm.ex2.approx.f(float %8281) #3, !dbg !304 + br label %__nv_exp2f.exit1129, !dbg !304 + +__nv_exp2f.exit1129: ; preds = %8318, %8320 + %.0.i1128 = phi float [ %8319, %8318 ], [ %8321, %8320 ], !dbg !304 + %8322 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1130 = icmp eq i32 %8322, 0, !dbg !304 + br i1 %.not.i1130, label %8325, label %8323, !dbg !304 + +8323: ; preds = %__nv_exp2f.exit1129 + %8324 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8282) #3, !dbg !304 + br label %__nv_exp2f.exit1132, !dbg !304 + +8325: ; preds = %__nv_exp2f.exit1129 + %8326 = tail call 
float @llvm.nvvm.ex2.approx.f(float %8282) #3, !dbg !304 + br label %__nv_exp2f.exit1132, !dbg !304 + +__nv_exp2f.exit1132: ; preds = %8323, %8325 + %.0.i1131 = phi float [ %8324, %8323 ], [ %8326, %8325 ], !dbg !304 + %8327 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1133 = icmp eq i32 %8327, 0, !dbg !304 + br i1 %.not.i1133, label %8330, label %8328, !dbg !304 + +8328: ; preds = %__nv_exp2f.exit1132 + %8329 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8283) #3, !dbg !304 + br label %__nv_exp2f.exit1135, !dbg !304 + +8330: ; preds = %__nv_exp2f.exit1132 + %8331 = tail call float @llvm.nvvm.ex2.approx.f(float %8283) #3, !dbg !304 + br label %__nv_exp2f.exit1135, !dbg !304 + +__nv_exp2f.exit1135: ; preds = %8328, %8330 + %.0.i1134 = phi float [ %8329, %8328 ], [ %8331, %8330 ], !dbg !304 + %8332 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1136 = icmp eq i32 %8332, 0, !dbg !304 + br i1 %.not.i1136, label %8335, label %8333, !dbg !304 + +8333: ; preds = %__nv_exp2f.exit1135 + %8334 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8284) #3, !dbg !304 + br label %__nv_exp2f.exit1138, !dbg !304 + +8335: ; preds = %__nv_exp2f.exit1135 + %8336 = tail call float @llvm.nvvm.ex2.approx.f(float %8284) #3, !dbg !304 + br label %__nv_exp2f.exit1138, !dbg !304 + +__nv_exp2f.exit1138: ; preds = %8333, %8335 + %.0.i1137 = phi float [ %8334, %8333 ], [ %8336, %8335 ], !dbg !304 + %8337 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1139 = icmp eq i32 %8337, 0, !dbg !304 + br i1 %.not.i1139, label %8340, label %8338, !dbg !304 + +8338: ; preds = %__nv_exp2f.exit1138 + %8339 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8285) #3, !dbg !304 + br label %__nv_exp2f.exit1141, !dbg !304 + +8340: ; preds = %__nv_exp2f.exit1138 + %8341 = tail call float @llvm.nvvm.ex2.approx.f(float %8285) #3, !dbg !304 + br label %__nv_exp2f.exit1141, !dbg !304 + +__nv_exp2f.exit1141: ; preds = %8338, 
%8340 + %.0.i1140 = phi float [ %8339, %8338 ], [ %8341, %8340 ], !dbg !304 + %8342 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1142 = icmp eq i32 %8342, 0, !dbg !304 + br i1 %.not.i1142, label %8345, label %8343, !dbg !304 + +8343: ; preds = %__nv_exp2f.exit1141 + %8344 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8286) #3, !dbg !304 + br label %__nv_exp2f.exit1144, !dbg !304 + +8345: ; preds = %__nv_exp2f.exit1141 + %8346 = tail call float @llvm.nvvm.ex2.approx.f(float %8286) #3, !dbg !304 + br label %__nv_exp2f.exit1144, !dbg !304 + +__nv_exp2f.exit1144: ; preds = %8343, %8345 + %.0.i1143 = phi float [ %8344, %8343 ], [ %8346, %8345 ], !dbg !304 + %8347 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1145 = icmp eq i32 %8347, 0, !dbg !304 + br i1 %.not.i1145, label %8350, label %8348, !dbg !304 + +8348: ; preds = %__nv_exp2f.exit1144 + %8349 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8287) #3, !dbg !304 + br label %__nv_exp2f.exit1147, !dbg !304 + +8350: ; preds = %__nv_exp2f.exit1144 + %8351 = tail call float @llvm.nvvm.ex2.approx.f(float %8287) #3, !dbg !304 + br label %__nv_exp2f.exit1147, !dbg !304 + +__nv_exp2f.exit1147: ; preds = %8348, %8350 + %.0.i1146 = phi float [ %8349, %8348 ], [ %8351, %8350 ], !dbg !304 + %8352 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1148 = icmp eq i32 %8352, 0, !dbg !304 + br i1 %.not.i1148, label %8355, label %8353, !dbg !304 + +8353: ; preds = %__nv_exp2f.exit1147 + %8354 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8288) #3, !dbg !304 + br label %__nv_exp2f.exit1150, !dbg !304 + +8355: ; preds = %__nv_exp2f.exit1147 + %8356 = tail call float @llvm.nvvm.ex2.approx.f(float %8288) #3, !dbg !304 + br label %__nv_exp2f.exit1150, !dbg !304 + +__nv_exp2f.exit1150: ; preds = %8353, %8355 + %.0.i1149 = phi float [ %8354, %8353 ], [ %8356, %8355 ], !dbg !304 + %8357 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, 
!dbg !304 + %.not.i1151 = icmp eq i32 %8357, 0, !dbg !304 + br i1 %.not.i1151, label %8360, label %8358, !dbg !304 + +8358: ; preds = %__nv_exp2f.exit1150 + %8359 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8289) #3, !dbg !304 + br label %__nv_exp2f.exit1153, !dbg !304 + +8360: ; preds = %__nv_exp2f.exit1150 + %8361 = tail call float @llvm.nvvm.ex2.approx.f(float %8289) #3, !dbg !304 + br label %__nv_exp2f.exit1153, !dbg !304 + +__nv_exp2f.exit1153: ; preds = %8358, %8360 + %.0.i1152 = phi float [ %8359, %8358 ], [ %8361, %8360 ], !dbg !304 + %8362 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1154 = icmp eq i32 %8362, 0, !dbg !304 + br i1 %.not.i1154, label %8365, label %8363, !dbg !304 + +8363: ; preds = %__nv_exp2f.exit1153 + %8364 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8290) #3, !dbg !304 + br label %__nv_exp2f.exit1156, !dbg !304 + +8365: ; preds = %__nv_exp2f.exit1153 + %8366 = tail call float @llvm.nvvm.ex2.approx.f(float %8290) #3, !dbg !304 + br label %__nv_exp2f.exit1156, !dbg !304 + +__nv_exp2f.exit1156: ; preds = %8363, %8365 + %.0.i1155 = phi float [ %8364, %8363 ], [ %8366, %8365 ], !dbg !304 + %8367 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1157 = icmp eq i32 %8367, 0, !dbg !304 + br i1 %.not.i1157, label %8370, label %8368, !dbg !304 + +8368: ; preds = %__nv_exp2f.exit1156 + %8369 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8291) #3, !dbg !304 + br label %__nv_exp2f.exit1159, !dbg !304 + +8370: ; preds = %__nv_exp2f.exit1156 + %8371 = tail call float @llvm.nvvm.ex2.approx.f(float %8291) #3, !dbg !304 + br label %__nv_exp2f.exit1159, !dbg !304 + +__nv_exp2f.exit1159: ; preds = %8368, %8370 + %.0.i1158 = phi float [ %8369, %8368 ], [ %8371, %8370 ], !dbg !304 + %8372 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1160 = icmp eq i32 %8372, 0, !dbg !304 + br i1 %.not.i1160, label %8375, label %8373, !dbg !304 + +8373: ; preds = 
%__nv_exp2f.exit1159 + %8374 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8292) #3, !dbg !304 + br label %__nv_exp2f.exit1162, !dbg !304 + +8375: ; preds = %__nv_exp2f.exit1159 + %8376 = tail call float @llvm.nvvm.ex2.approx.f(float %8292) #3, !dbg !304 + br label %__nv_exp2f.exit1162, !dbg !304 + +__nv_exp2f.exit1162: ; preds = %8373, %8375 + %.0.i1161 = phi float [ %8374, %8373 ], [ %8376, %8375 ], !dbg !304 + %8377 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1163 = icmp eq i32 %8377, 0, !dbg !304 + br i1 %.not.i1163, label %8380, label %8378, !dbg !304 + +8378: ; preds = %__nv_exp2f.exit1162 + %8379 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8293) #3, !dbg !304 + br label %__nv_exp2f.exit1165, !dbg !304 + +8380: ; preds = %__nv_exp2f.exit1162 + %8381 = tail call float @llvm.nvvm.ex2.approx.f(float %8293) #3, !dbg !304 + br label %__nv_exp2f.exit1165, !dbg !304 + +__nv_exp2f.exit1165: ; preds = %8378, %8380 + %.0.i1164 = phi float [ %8379, %8378 ], [ %8381, %8380 ], !dbg !304 + %8382 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1166 = icmp eq i32 %8382, 0, !dbg !304 + br i1 %.not.i1166, label %8385, label %8383, !dbg !304 + +8383: ; preds = %__nv_exp2f.exit1165 + %8384 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8294) #3, !dbg !304 + br label %__nv_exp2f.exit1168, !dbg !304 + +8385: ; preds = %__nv_exp2f.exit1165 + %8386 = tail call float @llvm.nvvm.ex2.approx.f(float %8294) #3, !dbg !304 + br label %__nv_exp2f.exit1168, !dbg !304 + +__nv_exp2f.exit1168: ; preds = %8383, %8385 + %.0.i1167 = phi float [ %8384, %8383 ], [ %8386, %8385 ], !dbg !304 + %8387 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1169 = icmp eq i32 %8387, 0, !dbg !304 + br i1 %.not.i1169, label %8390, label %8388, !dbg !304 + +8388: ; preds = %__nv_exp2f.exit1168 + %8389 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8295) #3, !dbg !304 + br label %__nv_exp2f.exit1171, !dbg 
!304 + +8390: ; preds = %__nv_exp2f.exit1168 + %8391 = tail call float @llvm.nvvm.ex2.approx.f(float %8295) #3, !dbg !304 + br label %__nv_exp2f.exit1171, !dbg !304 + +__nv_exp2f.exit1171: ; preds = %8388, %8390 + %.0.i1170 = phi float [ %8389, %8388 ], [ %8391, %8390 ], !dbg !304 + %8392 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1172 = icmp eq i32 %8392, 0, !dbg !304 + br i1 %.not.i1172, label %8395, label %8393, !dbg !304 + +8393: ; preds = %__nv_exp2f.exit1171 + %8394 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8296) #3, !dbg !304 + br label %__nv_exp2f.exit1174, !dbg !304 + +8395: ; preds = %__nv_exp2f.exit1171 + %8396 = tail call float @llvm.nvvm.ex2.approx.f(float %8296) #3, !dbg !304 + br label %__nv_exp2f.exit1174, !dbg !304 + +__nv_exp2f.exit1174: ; preds = %8393, %8395 + %.0.i1173 = phi float [ %8394, %8393 ], [ %8396, %8395 ], !dbg !304 + %8397 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1175 = icmp eq i32 %8397, 0, !dbg !304 + br i1 %.not.i1175, label %8400, label %8398, !dbg !304 + +8398: ; preds = %__nv_exp2f.exit1174 + %8399 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8297) #3, !dbg !304 + br label %__nv_exp2f.exit1177, !dbg !304 + +8400: ; preds = %__nv_exp2f.exit1174 + %8401 = tail call float @llvm.nvvm.ex2.approx.f(float %8297) #3, !dbg !304 + br label %__nv_exp2f.exit1177, !dbg !304 + +__nv_exp2f.exit1177: ; preds = %8398, %8400 + %.0.i1176 = phi float [ %8399, %8398 ], [ %8401, %8400 ], !dbg !304 + %8402 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1178 = icmp eq i32 %8402, 0, !dbg !304 + br i1 %.not.i1178, label %8405, label %8403, !dbg !304 + +8403: ; preds = %__nv_exp2f.exit1177 + %8404 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8298) #3, !dbg !304 + br label %__nv_exp2f.exit1180, !dbg !304 + +8405: ; preds = %__nv_exp2f.exit1177 + %8406 = tail call float @llvm.nvvm.ex2.approx.f(float %8298) #3, !dbg !304 + br label 
%__nv_exp2f.exit1180, !dbg !304 + +__nv_exp2f.exit1180: ; preds = %8403, %8405 + %.0.i1179 = phi float [ %8404, %8403 ], [ %8406, %8405 ], !dbg !304 + %8407 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1181 = icmp eq i32 %8407, 0, !dbg !304 + br i1 %.not.i1181, label %8410, label %8408, !dbg !304 + +8408: ; preds = %__nv_exp2f.exit1180 + %8409 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8299) #3, !dbg !304 + br label %__nv_exp2f.exit1183, !dbg !304 + +8410: ; preds = %__nv_exp2f.exit1180 + %8411 = tail call float @llvm.nvvm.ex2.approx.f(float %8299) #3, !dbg !304 + br label %__nv_exp2f.exit1183, !dbg !304 + +__nv_exp2f.exit1183: ; preds = %8408, %8410 + %.0.i1182 = phi float [ %8409, %8408 ], [ %8411, %8410 ], !dbg !304 + %8412 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1184 = icmp eq i32 %8412, 0, !dbg !304 + br i1 %.not.i1184, label %8415, label %8413, !dbg !304 + +8413: ; preds = %__nv_exp2f.exit1183 + %8414 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8300) #3, !dbg !304 + br label %__nv_exp2f.exit1186, !dbg !304 + +8415: ; preds = %__nv_exp2f.exit1183 + %8416 = tail call float @llvm.nvvm.ex2.approx.f(float %8300) #3, !dbg !304 + br label %__nv_exp2f.exit1186, !dbg !304 + +__nv_exp2f.exit1186: ; preds = %8413, %8415 + %.0.i1185 = phi float [ %8414, %8413 ], [ %8416, %8415 ], !dbg !304 + %8417 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1187 = icmp eq i32 %8417, 0, !dbg !304 + br i1 %.not.i1187, label %8420, label %8418, !dbg !304 + +8418: ; preds = %__nv_exp2f.exit1186 + %8419 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8301) #3, !dbg !304 + br label %__nv_exp2f.exit1189, !dbg !304 + +8420: ; preds = %__nv_exp2f.exit1186 + %8421 = tail call float @llvm.nvvm.ex2.approx.f(float %8301) #3, !dbg !304 + br label %__nv_exp2f.exit1189, !dbg !304 + +__nv_exp2f.exit1189: ; preds = %8418, %8420 + %.0.i1188 = phi float [ %8419, %8418 ], [ %8421, %8420 ], 
!dbg !304 + %8422 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1190 = icmp eq i32 %8422, 0, !dbg !304 + br i1 %.not.i1190, label %8425, label %8423, !dbg !304 + +8423: ; preds = %__nv_exp2f.exit1189 + %8424 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8302) #3, !dbg !304 + br label %__nv_exp2f.exit1192, !dbg !304 + +8425: ; preds = %__nv_exp2f.exit1189 + %8426 = tail call float @llvm.nvvm.ex2.approx.f(float %8302) #3, !dbg !304 + br label %__nv_exp2f.exit1192, !dbg !304 + +__nv_exp2f.exit1192: ; preds = %8423, %8425 + %.0.i1191 = phi float [ %8424, %8423 ], [ %8426, %8425 ], !dbg !304 + %8427 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1193 = icmp eq i32 %8427, 0, !dbg !304 + br i1 %.not.i1193, label %8430, label %8428, !dbg !304 + +8428: ; preds = %__nv_exp2f.exit1192 + %8429 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8303) #3, !dbg !304 + br label %__nv_exp2f.exit1195, !dbg !304 + +8430: ; preds = %__nv_exp2f.exit1192 + %8431 = tail call float @llvm.nvvm.ex2.approx.f(float %8303) #3, !dbg !304 + br label %__nv_exp2f.exit1195, !dbg !304 + +__nv_exp2f.exit1195: ; preds = %8428, %8430 + %.0.i1194 = phi float [ %8429, %8428 ], [ %8431, %8430 ], !dbg !304 + %8432 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1196 = icmp eq i32 %8432, 0, !dbg !304 + br i1 %.not.i1196, label %8435, label %8433, !dbg !304 + +8433: ; preds = %__nv_exp2f.exit1195 + %8434 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8304) #3, !dbg !304 + br label %__nv_exp2f.exit1198, !dbg !304 + +8435: ; preds = %__nv_exp2f.exit1195 + %8436 = tail call float @llvm.nvvm.ex2.approx.f(float %8304) #3, !dbg !304 + br label %__nv_exp2f.exit1198, !dbg !304 + +__nv_exp2f.exit1198: ; preds = %8433, %8435 + %.0.i1197 = phi float [ %8434, %8433 ], [ %8436, %8435 ], !dbg !304 + %8437 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1199 = icmp eq i32 %8437, 0, !dbg !304 + br i1 
%.not.i1199, label %8440, label %8438, !dbg !304 + +8438: ; preds = %__nv_exp2f.exit1198 + %8439 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8305) #3, !dbg !304 + br label %__nv_exp2f.exit1201, !dbg !304 + +8440: ; preds = %__nv_exp2f.exit1198 + %8441 = tail call float @llvm.nvvm.ex2.approx.f(float %8305) #3, !dbg !304 + br label %__nv_exp2f.exit1201, !dbg !304 + +__nv_exp2f.exit1201: ; preds = %8438, %8440 + %.0.i1200 = phi float [ %8439, %8438 ], [ %8441, %8440 ], !dbg !304 + %8442 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1202 = icmp eq i32 %8442, 0, !dbg !304 + br i1 %.not.i1202, label %8445, label %8443, !dbg !304 + +8443: ; preds = %__nv_exp2f.exit1201 + %8444 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8306) #3, !dbg !304 + br label %__nv_exp2f.exit1204, !dbg !304 + +8445: ; preds = %__nv_exp2f.exit1201 + %8446 = tail call float @llvm.nvvm.ex2.approx.f(float %8306) #3, !dbg !304 + br label %__nv_exp2f.exit1204, !dbg !304 + +__nv_exp2f.exit1204: ; preds = %8443, %8445 + %.0.i1203 = phi float [ %8444, %8443 ], [ %8446, %8445 ], !dbg !304 + %8447 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1205 = icmp eq i32 %8447, 0, !dbg !304 + br i1 %.not.i1205, label %8450, label %8448, !dbg !304 + +8448: ; preds = %__nv_exp2f.exit1204 + %8449 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8307) #3, !dbg !304 + br label %__nv_exp2f.exit1207, !dbg !304 + +8450: ; preds = %__nv_exp2f.exit1204 + %8451 = tail call float @llvm.nvvm.ex2.approx.f(float %8307) #3, !dbg !304 + br label %__nv_exp2f.exit1207, !dbg !304 + +__nv_exp2f.exit1207: ; preds = %8448, %8450 + %.0.i1206 = phi float [ %8449, %8448 ], [ %8451, %8450 ], !dbg !304 + %8452 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1208 = icmp eq i32 %8452, 0, !dbg !304 + br i1 %.not.i1208, label %8455, label %8453, !dbg !304 + +8453: ; preds = %__nv_exp2f.exit1207 + %8454 = tail call float 
@llvm.nvvm.ex2.approx.ftz.f(float %8308) #3, !dbg !304 + br label %__nv_exp2f.exit1210, !dbg !304 + +8455: ; preds = %__nv_exp2f.exit1207 + %8456 = tail call float @llvm.nvvm.ex2.approx.f(float %8308) #3, !dbg !304 + br label %__nv_exp2f.exit1210, !dbg !304 + +__nv_exp2f.exit1210: ; preds = %8453, %8455 + %.0.i1209 = phi float [ %8454, %8453 ], [ %8456, %8455 ], !dbg !304 + %8457 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1211 = icmp eq i32 %8457, 0, !dbg !304 + br i1 %.not.i1211, label %8460, label %8458, !dbg !304 + +8458: ; preds = %__nv_exp2f.exit1210 + %8459 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8309) #3, !dbg !304 + br label %__nv_exp2f.exit1213, !dbg !304 + +8460: ; preds = %__nv_exp2f.exit1210 + %8461 = tail call float @llvm.nvvm.ex2.approx.f(float %8309) #3, !dbg !304 + br label %__nv_exp2f.exit1213, !dbg !304 + +__nv_exp2f.exit1213: ; preds = %8458, %8460 + %.0.i1212 = phi float [ %8459, %8458 ], [ %8461, %8460 ], !dbg !304 + %8462 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1214 = icmp eq i32 %8462, 0, !dbg !304 + br i1 %.not.i1214, label %8465, label %8463, !dbg !304 + +8463: ; preds = %__nv_exp2f.exit1213 + %8464 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8310) #3, !dbg !304 + br label %__nv_exp2f.exit1216, !dbg !304 + +8465: ; preds = %__nv_exp2f.exit1213 + %8466 = tail call float @llvm.nvvm.ex2.approx.f(float %8310) #3, !dbg !304 + br label %__nv_exp2f.exit1216, !dbg !304 + +__nv_exp2f.exit1216: ; preds = %8463, %8465 + %.0.i1215 = phi float [ %8464, %8463 ], [ %8466, %8465 ], !dbg !304 + %8467 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !304 + %.not.i1217 = icmp eq i32 %8467, 0, !dbg !304 + br i1 %.not.i1217, label %8470, label %8468, !dbg !304 + +8468: ; preds = %__nv_exp2f.exit1216 + %8469 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8311) #3, !dbg !304 + br label %__nv_exp2f.exit1219, !dbg !304 + +8470: ; preds = %__nv_exp2f.exit1216 + 
%8471 = tail call float @llvm.nvvm.ex2.approx.f(float %8311) #3, !dbg !304 + br label %__nv_exp2f.exit1219, !dbg !304 + +__nv_exp2f.exit1219: ; preds = %8468, %8470 + %.0.i1218 = phi float [ %8469, %8468 ], [ %8471, %8470 ], !dbg !304 + %8472 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %7761, !dbg !290 + %8473 = insertelement <2 x float> poison, float %.0.i, i64 0, !dbg !305 + %8474 = insertelement <2 x float> %8473, float %.0.i1128, i64 1, !dbg !305 + %8475 = fptrunc <2 x float> %8474 to <2 x bfloat>, !dbg !305 + %8476 = insertelement <2 x float> poison, float %.0.i1131, i64 0, !dbg !305 + %8477 = insertelement <2 x float> %8476, float %.0.i1134, i64 1, !dbg !305 + %8478 = fptrunc <2 x float> %8477 to <2 x bfloat>, !dbg !305 + %8479 = insertelement <2 x float> poison, float %.0.i1137, i64 0, !dbg !305 + %8480 = insertelement <2 x float> %8479, float %.0.i1140, i64 1, !dbg !305 + %8481 = fptrunc <2 x float> %8480 to <2 x bfloat>, !dbg !305 + %8482 = insertelement <2 x float> poison, float %.0.i1143, i64 0, !dbg !305 + %8483 = insertelement <2 x float> %8482, float %.0.i1146, i64 1, !dbg !305 + %8484 = fptrunc <2 x float> %8483 to <2 x bfloat>, !dbg !305 + %8485 = insertelement <2 x float> poison, float %.0.i1149, i64 0, !dbg !305 + %8486 = insertelement <2 x float> %8485, float %.0.i1152, i64 1, !dbg !305 + %8487 = fptrunc <2 x float> %8486 to <2 x bfloat>, !dbg !305 + %8488 = insertelement <2 x float> poison, float %.0.i1155, i64 0, !dbg !305 + %8489 = insertelement <2 x float> %8488, float %.0.i1158, i64 1, !dbg !305 + %8490 = fptrunc <2 x float> %8489 to <2 x bfloat>, !dbg !305 + %8491 = insertelement <2 x float> poison, float %.0.i1161, i64 0, !dbg !305 + %8492 = insertelement <2 x float> %8491, float %.0.i1164, i64 1, !dbg !305 + %8493 = fptrunc <2 x float> %8492 to <2 x bfloat>, !dbg !305 + %8494 = insertelement <2 x float> poison, float %.0.i1167, i64 0, !dbg !305 + %8495 = insertelement <2 x 
float> %8494, float %.0.i1170, i64 1, !dbg !305 + %8496 = fptrunc <2 x float> %8495 to <2 x bfloat>, !dbg !305 + %8497 = insertelement <2 x float> poison, float %.0.i1173, i64 0, !dbg !305 + %8498 = insertelement <2 x float> %8497, float %.0.i1176, i64 1, !dbg !305 + %8499 = fptrunc <2 x float> %8498 to <2 x bfloat>, !dbg !305 + %8500 = insertelement <2 x float> poison, float %.0.i1179, i64 0, !dbg !305 + %8501 = insertelement <2 x float> %8500, float %.0.i1182, i64 1, !dbg !305 + %8502 = fptrunc <2 x float> %8501 to <2 x bfloat>, !dbg !305 + %8503 = insertelement <2 x float> poison, float %.0.i1185, i64 0, !dbg !305 + %8504 = insertelement <2 x float> %8503, float %.0.i1188, i64 1, !dbg !305 + %8505 = fptrunc <2 x float> %8504 to <2 x bfloat>, !dbg !305 + %8506 = insertelement <2 x float> poison, float %.0.i1191, i64 0, !dbg !305 + %8507 = insertelement <2 x float> %8506, float %.0.i1194, i64 1, !dbg !305 + %8508 = fptrunc <2 x float> %8507 to <2 x bfloat>, !dbg !305 + %8509 = insertelement <2 x float> poison, float %.0.i1197, i64 0, !dbg !305 + %8510 = insertelement <2 x float> %8509, float %.0.i1200, i64 1, !dbg !305 + %8511 = fptrunc <2 x float> %8510 to <2 x bfloat>, !dbg !305 + %8512 = insertelement <2 x float> poison, float %.0.i1203, i64 0, !dbg !305 + %8513 = insertelement <2 x float> %8512, float %.0.i1206, i64 1, !dbg !305 + %8514 = fptrunc <2 x float> %8513 to <2 x bfloat>, !dbg !305 + %8515 = insertelement <2 x float> poison, float %.0.i1209, i64 0, !dbg !305 + %8516 = insertelement <2 x float> %8515, float %.0.i1212, i64 1, !dbg !305 + %8517 = fptrunc <2 x float> %8516 to <2 x bfloat>, !dbg !305 + %8518 = insertelement <2 x float> poison, float %.0.i1215, i64 0, !dbg !305 + %8519 = insertelement <2 x float> %8518, float %.0.i1218, i64 1, !dbg !305 + %8520 = fptrunc <2 x float> %8519 to <2 x bfloat>, !dbg !305 + %8521 = bitcast <2 x bfloat> %8475 to i32, !dbg !306 + %8522 = bitcast <2 x bfloat> %8478 to i32, !dbg !306 + %8523 = bitcast <2 x bfloat> 
%8481 to i32, !dbg !306 + %8524 = bitcast <2 x bfloat> %8484 to i32, !dbg !306 + %8525 = bitcast <2 x bfloat> %8487 to i32, !dbg !306 + %8526 = bitcast <2 x bfloat> %8490 to i32, !dbg !306 + %8527 = bitcast <2 x bfloat> %8493 to i32, !dbg !306 + %8528 = bitcast <2 x bfloat> %8496 to i32, !dbg !306 + %8529 = bitcast <2 x bfloat> %8499 to i32, !dbg !306 + %8530 = bitcast <2 x bfloat> %8502 to i32, !dbg !306 + %8531 = bitcast <2 x bfloat> %8505 to i32, !dbg !306 + %8532 = bitcast <2 x bfloat> %8508 to i32, !dbg !306 + %8533 = bitcast <2 x bfloat> %8511 to i32, !dbg !306 + %8534 = bitcast <2 x bfloat> %8514 to i32, !dbg !306 + %8535 = bitcast <2 x bfloat> %8517 to i32, !dbg !306 + %8536 = bitcast <2 x bfloat> %8520 to i32, !dbg !306 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !306 + %8537 = ptrtoint ptr addrspace(3) %8472 to i32, !dbg !306 + %8538 = lshr exact i32 %8537, 4, !dbg !306 + %8539 = and i32 %8538, 16383, !dbg !306 + %8540 = zext nneg i32 %8539 to i64, !dbg !306 + %8541 = or disjoint i64 %8540, 4611686293338849280, !dbg !306 + %8542 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", 
"=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn5041610, float %.pn5021611, float %.pn5001612, float %.pn4981613, float %.pn4961614, float %.pn4941615, float %.pn4921616, float %.pn4901617, float %.pn4881618, float %.pn4861619, float %.pn4841620, float %.pn4821621, float %.pn4801622, float %.pn4781623, float %.pn4761624, float %.pn4741625, float %.pn4721626, float %.pn4701627, float %.pn4681628, float %.pn4661629, float %.pn4641630, float %.pn4621631, float %.pn4601632, float %.pn4581633, float %.pn4561634, float %.pn4541635, float %.pn4521636, float %.pn4501637, float %.pn4481638, float %.pn4461639, float %.pn4441640, float %.pn4421641, float %.pn4401642, float %.pn4381643, float %.pn4361644, float %.pn4341645, float %.pn4321646, float %.pn4301647, float %.pn4281648, float %.pn4261649, float %.pn4241650, float %.pn4221651, float %.pn4201652, float %.pn4181653, float %.pn4161654, float %.pn4141655, float %.pn4121656, float %.pn4101657, float %.pn4081658, float %.pn4061659, float %.pn4041660, float %.pn4021661, float %.pn4001662, float %.pn3981663, float %.pn3961664, float %.pn3941665, float %.pn3921666, float %.pn3901667, float %.pn3881668, float %.pn3861669, float %.pn3841670, float %.pn3821671, float %.pn3801672, float %.pn3781673, i32 %8521, i32 %8522, i32 %8523, i32 %8524, i64 %8541, i1 true) #3, !dbg !306 + %8543 = add i32 %8537, 2048, !dbg !306 + %8544 = lshr exact i32 %8543, 4, !dbg !306 + %8545 = and i32 %8544, 16383, !dbg !306 + %8546 = zext nneg i32 %8545 to i64, !dbg !306 + %8547 = or disjoint i64 %8546, 4611686293338849280, !dbg !306 + %8548 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 0, !dbg !306 + %8549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 1, !dbg !306 + %8550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 2, !dbg !306 + %8551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 3, !dbg !306 + %8552 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 4, !dbg !306 + %8553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 5, !dbg !306 + %8554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 6, !dbg !306 + %8555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 7, !dbg !306 + %8556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 8, !dbg !306 + %8557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 9, !dbg !306 + %8558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 10, !dbg !306 + %8559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 11, !dbg !306 + %8560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 12, !dbg !306 + %8561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 13, !dbg !306 + %8562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 14, !dbg !306 + %8563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 15, !dbg !306 + %8564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 16, !dbg !306 + %8565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 17, !dbg !306 + %8566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 18, !dbg !306 + %8567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 19, !dbg !306 + %8568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 20, !dbg !306 + %8569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 21, !dbg !306 + %8570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 22, !dbg !306 + %8571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 23, !dbg !306 + %8572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 24, !dbg !306 + %8573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 25, !dbg !306 + %8574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 26, !dbg !306 + %8575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 27, !dbg !306 + %8576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 28, !dbg !306 + %8577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 29, !dbg !306 + %8578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 30, !dbg !306 + %8579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 31, !dbg !306 + %8580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 32, !dbg !306 + %8581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 33, !dbg !306 + %8582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 34, !dbg !306 + %8583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 35, !dbg !306 + %8584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 36, !dbg !306 + %8585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 37, !dbg !306 + %8586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 38, !dbg !306 + %8587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 39, !dbg !306 + %8588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 40, !dbg !306 + %8589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 41, !dbg !306 + %8590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 42, !dbg !306 + %8591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 43, !dbg !306 + %8592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 44, !dbg !306 + %8593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 45, !dbg !306 + %8594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 46, !dbg !306 + %8595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 47, !dbg !306 + %8596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 48, !dbg !306 + %8597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 49, !dbg !306 + %8598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 50, !dbg !306 + %8599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 51, !dbg !306 + %8600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 52, !dbg !306 + %8601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 53, !dbg !306 + %8602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 54, !dbg !306 + %8603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 55, !dbg !306 + %8604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 56, !dbg !306 + %8605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 57, !dbg !306 + %8606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 58, !dbg !306 + %8607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 59, !dbg !306 + %8608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 60, !dbg !306 + %8609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 61, !dbg !306 + %8610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 62, !dbg !306 + %8611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8542, 63, !dbg !306 + %8612 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %8548, float %8549, float %8550, float %8551, float %8552, float %8553, float %8554, float %8555, float %8556, float %8557, float %8558, float %8559, float %8560, float %8561, float %8562, float %8563, float %8564, float %8565, float %8566, float %8567, float %8568, float %8569, float %8570, float %8571, float %8572, float %8573, float %8574, float %8575, float %8576, float %8577, float %8578, float %8579, float %8580, float %8581, float %8582, float %8583, float %8584, float %8585, float %8586, float %8587, float %8588, float %8589, float %8590, float %8591, float %8592, float %8593, float %8594, float %8595, float %8596, float %8597, float %8598, float %8599, float %8600, float %8601, float %8602, float %8603, float %8604, float %8605, float %8606, float %8607, float %8608, float %8609, float %8610, float %8611, i32 %8525, i32 %8526, i32 %8527, i32 %8528, i64 %8547, i1 true) #3, !dbg !306 + %8613 = add i32 %8537, 4096, !dbg !306 + %8614 = lshr exact i32 %8613, 4, !dbg !306 + %8615 = and i32 %8614, 16383, !dbg !306 + %8616 = zext nneg i32 %8615 to i64, !dbg !306 + %8617 = or 
disjoint i64 %8616, 4611686293338849280, !dbg !306 + %8618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 0, !dbg !306 + %8619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 1, !dbg !306 + %8620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 2, !dbg !306 + %8621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %8612, 3, !dbg !306 + %8622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 4, !dbg !306 + %8623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 5, !dbg !306 + %8624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 6, !dbg !306 + %8625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 7, !dbg 
!306 + %8626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 8, !dbg !306 + %8627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 9, !dbg !306 + %8628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 10, !dbg !306 + %8629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 11, !dbg !306 + %8630 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 12, !dbg !306 + %8631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 13, !dbg !306 + %8632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 14, !dbg !306 + %8633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 15, !dbg !306 + %8634 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 16, !dbg !306 + %8635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 17, !dbg !306 + %8636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 18, !dbg !306 + %8637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 19, !dbg !306 + %8638 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 20, !dbg !306 + %8639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 21, !dbg !306 + %8640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 22, !dbg !306 + %8641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 23, !dbg !306 + %8642 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 24, !dbg !306 + %8643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 25, !dbg !306 + %8644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 26, !dbg !306 + %8645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 27, !dbg !306 + %8646 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 28, !dbg !306 + %8647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 29, !dbg !306 + %8648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 30, !dbg !306 + %8649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 31, !dbg !306 + %8650 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 32, !dbg !306 + %8651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 33, !dbg !306 + %8652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 34, !dbg !306 + %8653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 35, !dbg !306 + %8654 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 36, !dbg !306 + %8655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 37, !dbg !306 + %8656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 38, !dbg !306 + %8657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 39, !dbg !306 + %8658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 40, !dbg !306 + %8659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 41, !dbg !306 + %8660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 42, !dbg !306 + %8661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 43, !dbg !306 + %8662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 44, !dbg !306 + %8663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 45, !dbg !306 + %8664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 46, !dbg !306 + %8665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 47, !dbg !306 + %8666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 48, !dbg !306 + %8667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 49, !dbg !306 + %8668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 50, !dbg !306 + %8669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 51, !dbg !306 + %8670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 52, !dbg !306 + %8671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 53, !dbg !306 + %8672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 54, !dbg !306 + %8673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 55, !dbg !306 + %8674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 56, !dbg !306 + %8675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 57, !dbg !306 + %8676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 58, !dbg !306 + %8677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 59, !dbg !306 + %8678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 60, !dbg !306 + %8679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 61, !dbg !306 + %8680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 62, !dbg !306 + %8681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8612, 63, !dbg !306 + %8682 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %8618, float %8619, float %8620, float %8621, float %8622, float %8623, float %8624, float %8625, float %8626, float %8627, float %8628, float %8629, float %8630, float %8631, float %8632, float %8633, float %8634, float %8635, float %8636, float %8637, float %8638, float %8639, float %8640, float %8641, float %8642, float %8643, float %8644, float %8645, float %8646, float %8647, float %8648, float %8649, float %8650, float %8651, float %8652, float %8653, float %8654, float %8655, float %8656, float %8657, float %8658, float %8659, float %8660, float %8661, float %8662, float %8663, float %8664, float %8665, float %8666, float %8667, float %8668, float %8669, float %8670, float %8671, float %8672, float %8673, float %8674, float %8675, float %8676, float %8677, float %8678, float %8679, float %8680, float %8681, i32 %8529, i32 %8530, i32 %8531, i32 %8532, i64 %8617, i1 true) #3, !dbg !306 + %8683 = add i32 %8537, 6144, !dbg !306 + %8684 = lshr exact 
i32 %8683, 4, !dbg !306 + %8685 = and i32 %8684, 16383, !dbg !306 + %8686 = zext nneg i32 %8685 to i64, !dbg !306 + %8687 = or disjoint i64 %8686, 4611686293338849280, !dbg !306 + %8688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 0, !dbg !306 + %8689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 1, !dbg !306 + %8690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 2, !dbg !306 + %8691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 3, !dbg !306 + %8692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 4, !dbg !306 + %8693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 5, !dbg !306 + %8694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 6, !dbg !306 + %8695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 7, !dbg !306 + %8696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 8, !dbg !306 + %8697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 9, !dbg !306 + %8698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 10, !dbg !306 + %8699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 11, !dbg !306 + %8700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 12, !dbg !306 + %8701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 13, !dbg !306 + %8702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 14, !dbg !306 + %8703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %8682, 15, !dbg !306 + %8704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 16, !dbg !306 + %8705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 17, !dbg !306 + %8706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 18, !dbg !306 + %8707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %8682, 19, !dbg !306 + %8708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 20, !dbg !306 + %8709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 21, !dbg !306 + %8710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 22, !dbg !306 + %8711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %8682, 23, !dbg !306 + %8712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 24, !dbg !306 + %8713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 25, !dbg !306 + %8714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 26, !dbg !306 + %8715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float } %8682, 27, !dbg !306 + %8716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 28, !dbg !306 + %8717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 29, !dbg !306 + %8718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 30, !dbg !306 + %8719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float } %8682, 31, !dbg !306 + %8720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 32, !dbg !306 + %8721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 33, !dbg !306 + %8722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 34, !dbg !306 + %8723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float 
} %8682, 35, !dbg !306 + %8724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 36, !dbg !306 + %8725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 37, !dbg !306 + %8726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 38, !dbg !306 + %8727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 39, !dbg 
!306 + %8728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 40, !dbg !306 + %8729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 41, !dbg !306 + %8730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 42, !dbg !306 + %8731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 43, !dbg !306 + %8732 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 44, !dbg !306 + %8733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 45, !dbg !306 + %8734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 46, !dbg !306 + %8735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 47, !dbg !306 + %8736 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 48, !dbg !306 + %8737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 49, !dbg !306 + %8738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 50, !dbg !306 + %8739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 51, !dbg !306 + %8740 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 52, !dbg !306 + %8741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 53, !dbg !306 + %8742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 54, !dbg !306 + %8743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 55, !dbg !306 + %8744 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 56, !dbg !306 + %8745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 57, !dbg !306 + %8746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 58, !dbg !306 + %8747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 59, !dbg !306 + %8748 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 60, !dbg !306 + %8749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 61, !dbg !306 + %8750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 62, !dbg !306 + %8751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8682, 63, !dbg !306 + %8752 = tail call { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %8688, float %8689, float %8690, float %8691, float %8692, float %8693, float %8694, float %8695, float %8696, float %8697, float %8698, float %8699, float %8700, float %8701, float %8702, float %8703, float %8704, float %8705, float %8706, float %8707, float %8708, float %8709, float %8710, float %8711, float %8712, float %8713, float %8714, float %8715, float %8716, float %8717, float %8718, float %8719, float %8720, float %8721, float %8722, float %8723, float %8724, float %8725, float %8726, float %8727, float %8728, float %8729, float %8730, float %8731, float %8732, float %8733, float %8734, float %8735, float %8736, float %8737, float %8738, float %8739, float %8740, float %8741, float %8742, float %8743, float %8744, float %8745, float %8746, float %8747, float %8748, float %8749, float %8750, float %8751, i32 %8533, i32 %8534, i32 
%8535, i32 %8536, i64 %8687, i1 true) #3, !dbg !306 + %8753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 0, !dbg !306 + %8754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 1, !dbg !306 + %8755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 2, !dbg !306 + %8756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %8752, 3, !dbg !306 + %8757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 4, !dbg !306 + %8758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 5, !dbg !306 + %8759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 6, !dbg !306 + %8760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 7, !dbg 
!306 + %8761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 8, !dbg !306 + %8762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 9, !dbg !306 + %8763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 10, !dbg !306 + %8764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 11, !dbg !306 + %8765 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 12, !dbg !306 + %8766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 13, !dbg !306 + %8767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 14, !dbg !306 + %8768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 15, !dbg !306 + %8769 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 16, !dbg !306 + %8770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 17, !dbg !306 + %8771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 18, !dbg !306 + %8772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 19, !dbg !306 + %8773 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 20, !dbg !306 + %8774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 21, !dbg !306 + %8775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 22, !dbg !306 + %8776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 23, !dbg !306 + %8777 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 24, !dbg !306 + %8778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 25, !dbg !306 + %8779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 26, !dbg !306 + %8780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 27, !dbg !306 + %8781 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 28, !dbg !306 + %8782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 29, !dbg !306 + %8783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 30, !dbg !306 + %8784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 31, !dbg !306 + %8785 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 32, !dbg !306 + %8786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 33, !dbg !306 + %8787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 34, !dbg !306 + %8788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 35, !dbg !306 + %8789 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 36, !dbg !306 + %8790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 37, !dbg !306 + %8791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 38, !dbg !306 + %8792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 39, !dbg !306 + %8793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 40, !dbg !306 + %8794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 41, !dbg !306 + %8795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 42, !dbg !306 + %8796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 43, !dbg !306 + %8797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 44, !dbg !306 + %8798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 45, !dbg !306 + %8799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 46, !dbg !306 + %8800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 47, !dbg !306 + %8801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 48, !dbg !306 + %8802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 49, !dbg !306 + %8803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 50, !dbg !306 + %8804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 51, !dbg !306 + %8805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 52, !dbg !306 + %8806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 53, !dbg !306 + %8807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 54, !dbg !306 + %8808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 55, !dbg !306 + %8809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 56, !dbg !306 + %8810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 57, !dbg !306 + %8811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 58, !dbg !306 + %8812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 59, !dbg !306 + %8813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 60, !dbg !306 + %8814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 61, !dbg !306 + %8815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 62, !dbg !306 + %8816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8752, 63, !dbg !306 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !306 + %8817 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %7763, !dbg 
!294 + %8818 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4812, !dbg !294 + %8819 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4815, !dbg !294 + %8820 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4817, !dbg !294 + %8821 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4819, !dbg !294 + %8822 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4821, !dbg !294 + %8823 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4823, !dbg !294 + %8824 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4825, !dbg !294 + %8825 = getelementptr inbounds nuw i8, ptr addrspace(3) %8817, i32 %4827, !dbg !294 + %8826 = add i32 %7831, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8827 = lshr exact i32 %8826, 4, !dbg !307 + %8828 = and i32 %8827, 16383, !dbg !307 + %8829 = zext nneg i32 %8828 to i64, !dbg !307 + %8830 = or disjoint i64 %8829, 4611686293372403712, !dbg !307 + %8831 = add i32 %7843, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8832 = lshr exact i32 %8831, 4, !dbg !307 + %8833 = and i32 %8832, 16383, !dbg !307 + %8834 = zext nneg i32 %8833 to i64, !dbg !307 + %8835 = or disjoint i64 %8834, 4611686293372403712, !dbg !307 + %8836 = add i32 %8537, 32, !dbg !307 + %8837 = lshr exact i32 %8836, 4, !dbg !307 + %8838 = and i32 %8837, 16383, !dbg !307 + %8839 = zext nneg i32 %8838 to i64, !dbg !307 + %8840 = or disjoint i64 %8839, 4611686293338849280, !dbg !307 + %8841 = add i32 %7887, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8842 = lshr exact i32 %8841, 4, !dbg !307 + %8843 = and i32 %8842, 16383, !dbg !307 + %8844 = zext nneg i32 %8843 to i64, !dbg !307 + %8845 = or disjoint i64 %8844, 4611686293372403712, !dbg !307 + %8846 = add i32 %8537, 64, !dbg !307 + %8847 = lshr exact i32 %8846, 
4, !dbg !307 + %8848 = and i32 %8847, 16383, !dbg !307 + %8849 = zext nneg i32 %8848 to i64, !dbg !307 + %8850 = or disjoint i64 %8849, 4611686293338849280, !dbg !307 + %8851 = add i32 %7931, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8852 = lshr exact i32 %8851, 4, !dbg !307 + %8853 = and i32 %8852, 16383, !dbg !307 + %8854 = zext nneg i32 %8853 to i64, !dbg !307 + %8855 = or disjoint i64 %8854, 4611686293372403712, !dbg !307 + %8856 = add i32 %8537, 96, !dbg !307 + %8857 = lshr exact i32 %8856, 4, !dbg !307 + %8858 = and i32 %8857, 16383, !dbg !307 + %8859 = zext nneg i32 %8858 to i64, !dbg !307 + %8860 = or disjoint i64 %8859, 4611686293338849280, !dbg !307 + %8861 = add i32 %7975, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8862 = lshr exact i32 %8861, 4, !dbg !307 + %8863 = and i32 %8862, 16383, !dbg !307 + %8864 = zext nneg i32 %8863 to i64, !dbg !307 + %8865 = or disjoint i64 %8864, 4611686293372403712, !dbg !307 + %8866 = add i32 %8537, 8192, !dbg !307 + %8867 = lshr exact i32 %8866, 4, !dbg !307 + %8868 = and i32 %8867, 16383, !dbg !307 + %8869 = zext nneg i32 %8868 to i64, !dbg !307 + %8870 = or disjoint i64 %8869, 4611686293338849280, !dbg !307 + %8871 = add i32 %8019, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8872 = lshr exact i32 %8871, 4, !dbg !307 + %8873 = and i32 %8872, 16383, !dbg !307 + %8874 = zext nneg i32 %8873 to i64, !dbg !307 + %8875 = or disjoint i64 %8874, 4611686293372403712, !dbg !307 + %8876 = add i32 %8537, 8224, !dbg !307 + %8877 = lshr exact i32 %8876, 4, !dbg !307 + %8878 = and i32 %8877, 16383, !dbg !307 + %8879 = zext nneg i32 %8878 to i64, !dbg !307 + %8880 = or disjoint i64 %8879, 4611686293338849280, !dbg !307 + %8881 = add i32 %8063, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to 
i32), !dbg !307 + %8882 = lshr exact i32 %8881, 4, !dbg !307 + %8883 = and i32 %8882, 16383, !dbg !307 + %8884 = zext nneg i32 %8883 to i64, !dbg !307 + %8885 = or disjoint i64 %8884, 4611686293372403712, !dbg !307 + %8886 = add i32 %8537, 8256, !dbg !307 + %8887 = lshr exact i32 %8886, 4, !dbg !307 + %8888 = and i32 %8887, 16383, !dbg !307 + %8889 = zext nneg i32 %8888 to i64, !dbg !307 + %8890 = or disjoint i64 %8889, 4611686293338849280, !dbg !307 + %8891 = add i32 %8107, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !307 + %8892 = lshr exact i32 %8891, 4, !dbg !307 + %8893 = and i32 %8892, 16383, !dbg !307 + %8894 = zext nneg i32 %8893 to i64, !dbg !307 + %8895 = or disjoint i64 %8894, 4611686293372403712, !dbg !307 + %8896 = add i32 %8537, 8288, !dbg !307 + %8897 = lshr exact i32 %8896, 4, !dbg !307 + %8898 = and i32 %8897, 16383, !dbg !307 + %8899 = zext nneg i32 %8898 to i64, !dbg !307 + %8900 = or disjoint i64 %8899, 4611686293338849280, !dbg !307 + %8901 = load <2 x float>, ptr addrspace(3) %8825, align 8, !dbg !294 + %8902 = load <2 x float>, ptr addrspace(3) %8824, align 8, !dbg !294 + %8903 = load <2 x float>, ptr addrspace(3) %8823, align 8, !dbg !294 + %8904 = load <2 x float>, ptr addrspace(3) %8822, align 8, !dbg !294 + %8905 = load <2 x float>, ptr addrspace(3) %8821, align 8, !dbg !294 + %8906 = load <2 x float>, ptr addrspace(3) %8820, align 8, !dbg !294 + %8907 = load <2 x float>, ptr addrspace(3) %8819, align 8, !dbg !294 + %8908 = load <2 x float>, ptr addrspace(3) %8818, align 8, !dbg !294 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !307 + %8909 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %8830, i64 %8541) #3, !dbg !307 + %8910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 0, !dbg !307 + %8911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 1, !dbg !307 + %8912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 2, !dbg !307 + %8913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 3, !dbg !307 + %8914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 4, !dbg !307 + %8915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 5, !dbg !307 + %8916 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 6, !dbg !307 + %8917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 7, !dbg !307 + %8918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 8, !dbg !307 + %8919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 9, !dbg !307 + %8920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 10, !dbg !307 + %8921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 11, !dbg !307 + %8922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 12, !dbg !307 + %8923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %8909, 13, !dbg !307 + %8924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 14, !dbg !307 + %8925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 15, !dbg !307 + %8926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 16, !dbg !307 + %8927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 17, !dbg !307 + %8928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 18, !dbg !307 + %8929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 19, !dbg !307 + %8930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 20, !dbg !307 + %8931 = extractvalue { 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 21, !dbg !307 + %8932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 22, !dbg !307 + %8933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 23, !dbg !307 + %8934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 24, !dbg !307 + %8935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 25, !dbg !307 + %8936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 26, !dbg !307 + %8937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 27, !dbg !307 + %8938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 28, !dbg !307 + %8939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 29, !dbg !307 + %8940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 30, !dbg !307 + %8941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8909, 31, !dbg !307 + %8942 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8910, float %8911, float %8912, float %8913, float %8914, float %8915, float %8916, float %8917, float %8918, float %8919, float %8920, float %8921, float %8922, float %8923, float %8924, float %8925, float %8926, float %8927, float %8928, float %8929, float %8930, float %8931, float %8932, float %8933, float %8934, float %8935, float %8936, float %8937, float %8938, float 
%8939, float %8940, float %8941, i64 %8835, i64 %8840, i1 true) #3, !dbg !307 + %8943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 0, !dbg !307 + %8944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 1, !dbg !307 + %8945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 2, !dbg !307 + %8946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 3, !dbg !307 + %8947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 4, !dbg !307 + %8948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 5, !dbg !307 + %8949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 6, !dbg !307 + %8950 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 7, !dbg !307 + %8951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 8, !dbg !307 + %8952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 9, !dbg !307 + %8953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 10, !dbg !307 + %8954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 11, !dbg !307 + %8955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 12, !dbg !307 + %8956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 13, !dbg !307 + %8957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 14, !dbg !307 + %8958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 15, !dbg !307 + %8959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 16, !dbg !307 + %8960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 17, !dbg !307 + %8961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 18, !dbg !307 + %8962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 19, !dbg !307 + %8963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 20, !dbg !307 + %8964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %8942, 21, !dbg !307 + %8965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 22, !dbg !307 + %8966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 23, !dbg !307 + %8967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 24, !dbg !307 + %8968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 25, !dbg !307 + %8969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 26, !dbg !307 + %8970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 27, !dbg !307 + %8971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 28, !dbg !307 + %8972 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 29, !dbg !307 + %8973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 30, !dbg !307 + %8974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8942, 31, !dbg !307 + %8975 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8943, float %8944, float %8945, float %8946, float %8947, float %8948, float %8949, float %8950, float %8951, float %8952, float %8953, float %8954, float %8955, float %8956, float %8957, float %8958, float %8959, float %8960, float %8961, float %8962, float %8963, float %8964, float %8965, float %8966, float %8967, float %8968, float %8969, float %8970, float %8971, float %8972, float %8973, float %8974, i64 %8845, i64 %8850, i1 true) #3, !dbg !307 + %8976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 0, !dbg !307 + %8977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 1, !dbg !307 + %8978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 2, !dbg !307 + %8979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 3, !dbg !307 + %8980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 4, !dbg !307 + %8981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 5, !dbg !307 + %8982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 6, !dbg !307 + %8983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %8975, 7, !dbg !307 + %8984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 8, !dbg !307 + %8985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 9, !dbg !307 + %8986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 10, !dbg !307 + %8987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 11, !dbg !307 + %8988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 12, !dbg !307 + %8989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 13, !dbg !307 + %8990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 14, !dbg !307 + %8991 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 15, !dbg !307 + %8992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 16, !dbg !307 + %8993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 17, !dbg !307 + %8994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 18, !dbg !307 + %8995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 19, !dbg !307 + %8996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 20, !dbg !307 + %8997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 21, !dbg !307 + %8998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %8975, 22, !dbg !307 + %8999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 23, !dbg !307 + %9000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 24, !dbg !307 + %9001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 25, !dbg !307 + %9002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 26, !dbg !307 + %9003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 27, !dbg !307 + %9004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 28, !dbg !307 + %9005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 29, !dbg !307 + 
%9006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 30, !dbg !307 + %9007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8975, 31, !dbg !307 + %9008 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8976, float %8977, float %8978, float %8979, float %8980, float %8981, float %8982, float %8983, float %8984, float %8985, float %8986, float %8987, float %8988, float %8989, float %8990, float %8991, float %8992, float %8993, float %8994, float %8995, float %8996, float %8997, float %8998, float %8999, float %9000, float %9001, float %9002, float %9003, float %9004, float %9005, float %9006, float %9007, i64 %8855, i64 %8860, i1 true) #3, !dbg !307 + %9009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 0, !dbg !307 + %9010 = extractvalue { float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 1, !dbg !307 + %9011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 2, !dbg !307 + %9012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 3, !dbg !307 + %9013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 4, !dbg !307 + %9014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 5, !dbg !307 + %9015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 6, !dbg !307 + %9016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 7, !dbg !307 + %9017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %9008, 8, !dbg !307 + %9018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 9, !dbg !307 + %9019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 10, !dbg !307 + %9020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 11, !dbg !307 + %9021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 12, !dbg !307 + %9022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 13, !dbg !307 + %9023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 14, !dbg !307 + %9024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 15, !dbg !307 + 
%9025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 16, !dbg !307 + %9026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 17, !dbg !307 + %9027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 18, !dbg !307 + %9028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 19, !dbg !307 + %9029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 20, !dbg !307 + %9030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 21, !dbg !307 + %9031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 22, !dbg !307 + %9032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 23, !dbg !307 + %9033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 24, !dbg !307 + %9034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 25, !dbg !307 + %9035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 26, !dbg !307 + %9036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 27, !dbg !307 + %9037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 28, !dbg !307 + %9038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 29, !dbg !307 + %9039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float } %9008, 30, !dbg !307 + %9040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9008, 31, !dbg !307 + %9041 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9009, float %9010, float %9011, float %9012, float %9013, float %9014, float %9015, float %9016, float %9017, float %9018, float %9019, float %9020, float %9021, float %9022, float %9023, float %9024, float %9025, float %9026, float %9027, float %9028, float %9029, float %9030, float %9031, float %9032, float %9033, float %9034, float %9035, float %9036, float %9037, float %9038, float %9039, float %9040, i64 %8865, i64 %8870, i1 true) #3, !dbg !307 + %9042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 0, !dbg !307 + %9043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 1, !dbg !307 
+ %9044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 2, !dbg !307 + %9045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 3, !dbg !307 + %9046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 4, !dbg !307 + %9047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 5, !dbg !307 + %9048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 6, !dbg !307 + %9049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 7, !dbg !307 + %9050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 8, !dbg !307 + %9051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 9, !dbg !307 + %9052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 10, !dbg !307 + %9053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 11, !dbg !307 + %9054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 12, !dbg !307 + %9055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 13, !dbg !307 + %9056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 14, !dbg !307 + %9057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 15, !dbg !307 + %9058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %9041, 16, !dbg !307 + %9059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 17, !dbg !307 + %9060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 18, !dbg !307 + %9061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 19, !dbg !307 + %9062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 20, !dbg !307 + %9063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 21, !dbg !307 + %9064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 22, !dbg !307 + %9065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 23, !dbg !307 + %9066 = extractvalue { float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 24, !dbg !307 + %9067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 25, !dbg !307 + %9068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 26, !dbg !307 + %9069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 27, !dbg !307 + %9070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 28, !dbg !307 + %9071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 29, !dbg !307 + %9072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 30, !dbg !307 + %9073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9041, 31, !dbg !307 + %9074 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9042, float %9043, float %9044, float %9045, float %9046, float %9047, float %9048, float %9049, float %9050, float %9051, float %9052, float %9053, float %9054, float %9055, float %9056, float %9057, float %9058, float %9059, float %9060, float %9061, float %9062, float %9063, float %9064, float %9065, float %9066, float %9067, float %9068, float %9069, float %9070, float %9071, float %9072, float %9073, i64 %8875, i64 %8880, i1 true) #3, !dbg !307 + %9075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 0, !dbg !307 + %9076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 1, !dbg !307 + %9077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %9074, 2, !dbg !307 + %9078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 3, !dbg !307 + %9079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 4, !dbg !307 + %9080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 5, !dbg !307 + %9081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 6, !dbg !307 + %9082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 7, !dbg !307 + %9083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 8, !dbg !307 + %9084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 9, !dbg !307 + %9085 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 10, !dbg !307 + %9086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 11, !dbg !307 + %9087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 12, !dbg !307 + %9088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 13, !dbg !307 + %9089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 14, !dbg !307 + %9090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 15, !dbg !307 + %9091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 16, !dbg !307 + %9092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 17, !dbg !307 + %9093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 18, !dbg !307 + %9094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 19, !dbg !307 + %9095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 20, !dbg !307 + %9096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 21, !dbg !307 + %9097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 22, !dbg !307 + %9098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 23, !dbg !307 + %9099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float } %9074, 24, !dbg !307 + %9100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 25, !dbg !307 + %9101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 26, !dbg !307 + %9102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 27, !dbg !307 + %9103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 28, !dbg !307 + %9104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 29, !dbg !307 + %9105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 30, !dbg !307 + %9106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9074, 31, !dbg !307 + %9107 = tail call { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9075, float %9076, float %9077, float %9078, float %9079, float %9080, float %9081, float %9082, float %9083, float %9084, float %9085, float %9086, float %9087, float %9088, float %9089, float %9090, float %9091, float %9092, float %9093, float %9094, float %9095, float %9096, float %9097, float %9098, float %9099, float %9100, float %9101, float %9102, float %9103, float %9104, float %9105, float %9106, i64 %8885, i64 %8890, i1 true) #3, !dbg !307 + %9108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 0, !dbg !307 + %9109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 1, !dbg !307 + %9110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 2, !dbg !307 + %9111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 3, !dbg !307 + %9112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 4, !dbg !307 + %9113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 5, !dbg !307 + %9114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 6, !dbg !307 + %9115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 7, !dbg !307 + %9116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 8, !dbg !307 + %9117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 9, !dbg !307 + %9118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%9107, 10, !dbg !307 + %9119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 11, !dbg !307 + %9120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 12, !dbg !307 + %9121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 13, !dbg !307 + %9122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 14, !dbg !307 + %9123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 15, !dbg !307 + %9124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 16, !dbg !307 + %9125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 17, !dbg !307 + %9126 = extractvalue { float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 18, !dbg !307 + %9127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 19, !dbg !307 + %9128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 20, !dbg !307 + %9129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 21, !dbg !307 + %9130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 22, !dbg !307 + %9131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 23, !dbg !307 + %9132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 24, !dbg !307 + %9133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %9107, 25, !dbg !307 + %9134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 26, !dbg !307 + %9135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 27, !dbg !307 + %9136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 28, !dbg !307 + %9137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 29, !dbg !307 + %9138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 30, !dbg !307 + %9139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9107, 31, !dbg !307 + %9140 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect 
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9108, float %9109, float %9110, float %9111, float %9112, float %9113, float %9114, float %9115, float %9116, float %9117, float %9118, float %9119, float %9120, float %9121, float %9122, float %9123, float %9124, float %9125, float %9126, float %9127, float %9128, float %9129, float %9130, float %9131, float %9132, float %9133, float %9134, float %9135, float %9136, float %9137, float %9138, float %9139, i64 %8895, i64 %8900, i1 true) #3, !dbg !307 + %9141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 0, !dbg !307 + %9142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 1, !dbg !307 + %9143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 2, !dbg !307 + %9144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 3, !dbg !307 + %9145 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 4, !dbg !307 + %9146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 5, !dbg !307 + %9147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 6, !dbg !307 + %9148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 7, !dbg !307 + %9149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 8, !dbg !307 + %9150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 9, !dbg !307 + %9151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 10, !dbg !307 + %9152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %9140, 11, !dbg !307 + %9153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 12, !dbg !307 + %9154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 13, !dbg !307 + %9155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 14, !dbg !307 + %9156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 15, !dbg !307 + %9157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 16, !dbg !307 + %9158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 17, !dbg !307 + %9159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 18, !dbg 
!307 + %9160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 19, !dbg !307 + %9161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 20, !dbg !307 + %9162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 21, !dbg !307 + %9163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 22, !dbg !307 + %9164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 23, !dbg !307 + %9165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 24, !dbg !307 + %9166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 25, !dbg !307 + %9167 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 26, !dbg !307 + %9168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 27, !dbg !307 + %9169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 28, !dbg !307 + %9170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 29, !dbg !307 + %9171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 30, !dbg !307 + %9172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9140, 31, !dbg !307 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !307 + %9173 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: 
$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %9141, float %9142, float %9143, float %9144, float %9145, float %9146, float %9147, float %9148, float %9149, float %9150, float %9151, float %9152, float %9153, float %9154, float %9155, float %9156, float %9157, float %9158, float %9159, float %9160, float %9161, float %9162, float %9163, float %9164, float %9165, float %9166, float %9167, float %9168, float %9169, float %9170, float %9171, float %9172, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %8472, i32 0, i32 0) #3, !dbg !307 + %9174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 0, !dbg !307 + %9175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 1, !dbg !307 + %9176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 2, !dbg !307 + %9177 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 3, !dbg !307 + %9178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 4, !dbg !307 + %9179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 5, !dbg !307 + %9180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 6, !dbg !307 + %9181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 7, !dbg !307 + %9182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 8, !dbg !307 + %9183 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 9, !dbg !307 + %9184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 10, !dbg !307 + %9185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 11, !dbg !307 + %9186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 12, !dbg !307 + %9187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 13, !dbg !307 + %9188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 14, !dbg !307 + %9189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 15, !dbg !307 + %9190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 16, !dbg !307 + %9191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 17, !dbg !307 + %9192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 18, !dbg !307 + %9193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 19, !dbg !307 + %9194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 20, !dbg !307 + %9195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 21, !dbg !307 + %9196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 22, !dbg !307 + %9197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 23, !dbg !307 + %9198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 24, !dbg !307 + %9199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 25, !dbg !307 + %9200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 26, !dbg !307 + %9201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 27, !dbg !307 + %9202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 28, !dbg !307 + %9203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 29, !dbg !307 + %9204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 30, !dbg !307 + %9205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9173, 31, !dbg !307 + %9206 = insertelement <2 x float> poison, float %9176, i64 0, !dbg !308 + %9207 = insertelement <2 x float> %9206, float %9177, i64 1, !dbg !308 + %9208 = fsub <2 x float> %9207, %8908, !dbg !308 + %9209 = insertelement <2 x float> poison, float %9180, i64 0, !dbg !308 + %9210 = insertelement <2 x float> %9209, float %9181, i64 1, !dbg !308 + %9211 = fsub <2 x float> %9210, %8907, !dbg !308 + %9212 = insertelement <2 x float> poison, float %9184, i64 0, !dbg !308 + %9213 = insertelement <2 x float> %9212, float %9185, i64 1, !dbg !308 + %9214 = fsub <2 x 
float> %9213, %8906, !dbg !308 + %9215 = insertelement <2 x float> poison, float %9188, i64 0, !dbg !308 + %9216 = insertelement <2 x float> %9215, float %9189, i64 1, !dbg !308 + %9217 = fsub <2 x float> %9216, %8905, !dbg !308 + %9218 = insertelement <2 x float> poison, float %9192, i64 0, !dbg !308 + %9219 = insertelement <2 x float> %9218, float %9193, i64 1, !dbg !308 + %9220 = fsub <2 x float> %9219, %8904, !dbg !308 + %9221 = insertelement <2 x float> poison, float %9196, i64 0, !dbg !308 + %9222 = insertelement <2 x float> %9221, float %9197, i64 1, !dbg !308 + %9223 = fsub <2 x float> %9222, %8903, !dbg !308 + %9224 = insertelement <2 x float> poison, float %9200, i64 0, !dbg !308 + %9225 = insertelement <2 x float> %9224, float %9201, i64 1, !dbg !308 + %9226 = fsub <2 x float> %9225, %8902, !dbg !308 + %9227 = insertelement <2 x float> poison, float %9204, i64 0, !dbg !308 + %9228 = insertelement <2 x float> %9227, float %9205, i64 1, !dbg !308 + %9229 = fsub <2 x float> %9228, %8901, !dbg !308 + %9230 = fmul <2 x float> %8477, %9208, !dbg !309 + %9231 = fmul <2 x float> %8483, %9211, !dbg !309 + %9232 = fmul <2 x float> %8489, %9214, !dbg !309 + %9233 = fmul <2 x float> %8495, %9217, !dbg !309 + %9234 = fmul <2 x float> %8501, %9220, !dbg !309 + %9235 = fmul <2 x float> %8507, %9223, !dbg !309 + %9236 = fmul <2 x float> %8513, %9226, !dbg !309 + %9237 = fmul <2 x float> %8519, %9229, !dbg !309 + %9238 = insertelement <2 x float> poison, float %9174, i64 0, !dbg !308 + %9239 = insertelement <2 x float> %9238, float %9175, i64 1, !dbg !308 + %9240 = fsub <2 x float> %9239, %8908, !dbg !308 + %9241 = fmul <2 x float> %8474, %9240, !dbg !309 + %9242 = fptrunc <2 x float> %9241 to <2 x bfloat>, !dbg !310 + %9243 = fptrunc <2 x float> %9230 to <2 x bfloat>, !dbg !310 + %9244 = insertelement <2 x float> poison, float %9178, i64 0, !dbg !308 + %9245 = insertelement <2 x float> %9244, float %9179, i64 1, !dbg !308 + %9246 = fsub <2 x float> %9245, %8907, !dbg 
!308 + %9247 = fmul <2 x float> %8480, %9246, !dbg !309 + %9248 = fptrunc <2 x float> %9247 to <2 x bfloat>, !dbg !310 + %9249 = fptrunc <2 x float> %9231 to <2 x bfloat>, !dbg !310 + %9250 = insertelement <2 x float> poison, float %9182, i64 0, !dbg !308 + %9251 = insertelement <2 x float> %9250, float %9183, i64 1, !dbg !308 + %9252 = fsub <2 x float> %9251, %8906, !dbg !308 + %9253 = fmul <2 x float> %8486, %9252, !dbg !309 + %9254 = fptrunc <2 x float> %9253 to <2 x bfloat>, !dbg !310 + %9255 = fptrunc <2 x float> %9232 to <2 x bfloat>, !dbg !310 + %9256 = insertelement <2 x float> poison, float %9186, i64 0, !dbg !308 + %9257 = insertelement <2 x float> %9256, float %9187, i64 1, !dbg !308 + %9258 = fsub <2 x float> %9257, %8905, !dbg !308 + %9259 = fmul <2 x float> %8492, %9258, !dbg !309 + %9260 = fptrunc <2 x float> %9259 to <2 x bfloat>, !dbg !310 + %9261 = fptrunc <2 x float> %9233 to <2 x bfloat>, !dbg !310 + %9262 = insertelement <2 x float> poison, float %9190, i64 0, !dbg !308 + %9263 = insertelement <2 x float> %9262, float %9191, i64 1, !dbg !308 + %9264 = fsub <2 x float> %9263, %8904, !dbg !308 + %9265 = fmul <2 x float> %8498, %9264, !dbg !309 + %9266 = fptrunc <2 x float> %9265 to <2 x bfloat>, !dbg !310 + %9267 = fptrunc <2 x float> %9234 to <2 x bfloat>, !dbg !310 + %9268 = insertelement <2 x float> poison, float %9194, i64 0, !dbg !308 + %9269 = insertelement <2 x float> %9268, float %9195, i64 1, !dbg !308 + %9270 = fsub <2 x float> %9269, %8903, !dbg !308 + %9271 = fmul <2 x float> %8504, %9270, !dbg !309 + %9272 = fptrunc <2 x float> %9271 to <2 x bfloat>, !dbg !310 + %9273 = fptrunc <2 x float> %9235 to <2 x bfloat>, !dbg !310 + %9274 = insertelement <2 x float> poison, float %9198, i64 0, !dbg !308 + %9275 = insertelement <2 x float> %9274, float %9199, i64 1, !dbg !308 + %9276 = fsub <2 x float> %9275, %8902, !dbg !308 + %9277 = fmul <2 x float> %8510, %9276, !dbg !309 + %9278 = fptrunc <2 x float> %9277 to <2 x bfloat>, !dbg !310 + 
%9279 = fptrunc <2 x float> %9236 to <2 x bfloat>, !dbg !310 + %9280 = insertelement <2 x float> poison, float %9202, i64 0, !dbg !308 + %9281 = insertelement <2 x float> %9280, float %9203, i64 1, !dbg !308 + %9282 = fsub <2 x float> %9281, %8901, !dbg !308 + %9283 = fmul <2 x float> %8516, %9282, !dbg !309 + %9284 = fptrunc <2 x float> %9283 to <2 x bfloat>, !dbg !310 + %9285 = fptrunc <2 x float> %9237 to <2 x bfloat>, !dbg !310 + %9286 = bitcast <2 x bfloat> %9242 to i32, !dbg !311 + %9287 = bitcast <2 x bfloat> %9243 to i32, !dbg !311 + %9288 = bitcast <2 x bfloat> %9248 to i32, !dbg !311 + %9289 = bitcast <2 x bfloat> %9249 to i32, !dbg !311 + %9290 = bitcast <2 x bfloat> %9254 to i32, !dbg !311 + %9291 = bitcast <2 x bfloat> %9255 to i32, !dbg !311 + %9292 = bitcast <2 x bfloat> %9260 to i32, !dbg !311 + %9293 = bitcast <2 x bfloat> %9261 to i32, !dbg !311 + %9294 = bitcast <2 x bfloat> %9266 to i32, !dbg !311 + %9295 = bitcast <2 x bfloat> %9267 to i32, !dbg !311 + %9296 = bitcast <2 x bfloat> %9272 to i32, !dbg !311 + %9297 = bitcast <2 x bfloat> %9273 to i32, !dbg !311 + %9298 = bitcast <2 x bfloat> %9278 to i32, !dbg !311 + %9299 = bitcast <2 x bfloat> %9279 to i32, !dbg !311 + %9300 = bitcast <2 x bfloat> %9284 to i32, !dbg !311 + %9301 = bitcast <2 x bfloat> %9285 to i32, !dbg !311 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !311 + %9302 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 
{$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn3761546, float %.pn3741547, float %.pn3721548, float %.pn3701549, float %.pn3681550, float %.pn3661551, float %.pn3641552, float %.pn3621553, float %.pn3601554, float %.pn3581555, float %.pn3561556, float %.pn3541557, float %.pn3521558, float %.pn3501559, float %.pn3481560, float %.pn3461561, float %.pn3441562, float %.pn3421563, float %.pn3401564, float %.pn3381565, float %.pn3361566, float %.pn3341567, float %.pn3321568, float %.pn3301569, float %.pn3281570, float %.pn3261571, float %.pn3241572, float %.pn3221573, float %.pn3201574, float %.pn3181575, float %.pn3161576, float %.pn3141577, float %.pn3121578, float %.pn3101579, float %.pn3081580, float %.pn3061581, float %.pn3041582, float %.pn3021583, float %.pn3001584, float %.pn2981585, float %.pn2961586, float %.pn2941587, float %.pn2921588, float %.pn2901589, float %.pn2881590, float %.pn2861591, float %.pn2841592, float %.pn2821593, float %.pn2801594, float %.pn2781595, float %.pn2761596, float %.pn2741597, float %.pn2721598, float %.pn2701599, float %.pn2681600, float %.pn2661601, float %.pn2641602, float %.pn2621603, float %.pn2601604, float %.pn2581605, float %.pn2561606, float %.pn2541607, float %.pn2521608, float %.pn2501609, i32 %9286, i32 %9287, i32 %9288, i32 %9289, i64 %7841, i1 true) #3, !dbg !311 + %9303 = add i32 
%7837, 2048, !dbg !311 + %9304 = lshr exact i32 %9303, 4, !dbg !311 + %9305 = and i32 %9304, 16383, !dbg !311 + %9306 = zext nneg i32 %9305 to i64, !dbg !311 + %9307 = or disjoint i64 %9306, 4611686293338849280, !dbg !311 + %9308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 0, !dbg !311 + %9309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 1, !dbg !311 + %9310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 2, !dbg !311 + %9311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 3, !dbg !311 + %9312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 4, !dbg !311 + %9313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 5, !dbg !311 + %9314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 6, !dbg !311 + %9315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 7, !dbg !311 + %9316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 8, !dbg !311 + %9317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 9, !dbg !311 + %9318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 10, !dbg !311 + %9319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 11, !dbg !311 + %9320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 12, !dbg !311 + %9321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 13, !dbg !311 + %9322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 14, !dbg !311 + %9323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 15, !dbg !311 + %9324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 16, !dbg !311 + %9325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 17, !dbg !311 + %9326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 18, !dbg !311 + %9327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 19, !dbg !311 + %9328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 20, !dbg !311 + %9329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 21, !dbg !311 + %9330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 22, !dbg !311 + %9331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 23, !dbg !311 + %9332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 24, !dbg !311 + %9333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 25, !dbg !311 + %9334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 26, !dbg !311 + %9335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %9302, 27, !dbg !311 + %9336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 28, !dbg !311 + %9337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 29, !dbg !311 + %9338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 30, !dbg !311 + %9339 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float } %9302, 31, !dbg !311 + %9340 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 32, !dbg !311 + %9341 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 33, !dbg !311 + %9342 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 34, !dbg !311 + %9343 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float } %9302, 35, !dbg !311 + %9344 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 36, !dbg !311 + %9345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 37, !dbg !311 + %9346 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 38, !dbg !311 + %9347 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %9302, 39, !dbg !311 + %9348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 40, !dbg !311 + %9349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 41, !dbg !311 + %9350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 42, !dbg !311 + %9351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %9302, 43, !dbg !311 + %9352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 44, !dbg !311 + %9353 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 45, !dbg !311 + %9354 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 46, !dbg !311 + %9355 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %9302, 47, !dbg !311 + %9356 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 48, !dbg !311 + %9357 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 49, !dbg !311 + %9358 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 50, !dbg !311 + %9359 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } 
%9302, 51, !dbg !311 + %9360 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 52, !dbg !311 + %9361 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 53, !dbg !311 + %9362 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 54, !dbg !311 + %9363 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 55, !dbg 
!311 + %9364 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 56, !dbg !311 + %9365 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 57, !dbg !311 + %9366 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 58, !dbg !311 + %9367 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 59, !dbg !311 + %9368 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 60, !dbg !311 + %9369 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 61, !dbg !311 + %9370 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 62, !dbg !311 + %9371 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9302, 63, !dbg !311 + %9372 = tail call { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9308, float %9309, float %9310, float %9311, float %9312, float %9313, float %9314, float %9315, float %9316, float %9317, float %9318, float %9319, float %9320, float %9321, float %9322, float %9323, float %9324, float %9325, float %9326, float %9327, float %9328, float %9329, float %9330, float %9331, float %9332, float %9333, float %9334, float %9335, float %9336, float %9337, float %9338, float %9339, float %9340, float %9341, float %9342, float %9343, float %9344, float %9345, float %9346, float %9347, float %9348, float %9349, float %9350, float %9351, float %9352, float %9353, float %9354, float %9355, float %9356, float %9357, float %9358, float %9359, float %9360, float %9361, float %9362, float %9363, float %9364, float %9365, float %9366, float %9367, float %9368, float 
%9369, float %9370, float %9371, i32 %9290, i32 %9291, i32 %9292, i32 %9293, i64 %9307, i1 true) #3, !dbg !311 + %9373 = add i32 %7837, 4096, !dbg !311 + %9374 = lshr exact i32 %9373, 4, !dbg !311 + %9375 = and i32 %9374, 16383, !dbg !311 + %9376 = zext nneg i32 %9375 to i64, !dbg !311 + %9377 = or disjoint i64 %9376, 4611686293338849280, !dbg !311 + %9378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 0, !dbg !311 + %9379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 1, !dbg !311 + %9380 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 2, !dbg !311 + %9381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 3, !dbg !311 + %9382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 4, !dbg !311 + %9383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 5, !dbg !311 + %9384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 6, !dbg !311 + %9385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 7, !dbg !311 + %9386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 8, !dbg !311 + %9387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 9, !dbg !311 + %9388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 10, !dbg !311 + %9389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 11, !dbg !311 + %9390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 12, !dbg !311 + %9391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 13, !dbg !311 + %9392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 14, !dbg !311 + %9393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 15, !dbg !311 + %9394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 16, !dbg !311 + %9395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 17, !dbg !311 + %9396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 18, !dbg !311 + %9397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 19, !dbg !311 + %9398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 20, !dbg !311 + %9399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 21, !dbg !311 + %9400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 22, !dbg !311 + %9401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 23, !dbg !311 + %9402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 24, !dbg !311 + %9403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 25, !dbg !311 + %9404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 26, !dbg !311 + %9405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 27, !dbg !311 + %9406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 28, !dbg !311 + %9407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 29, !dbg !311 + %9408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 30, !dbg !311 + %9409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 31, !dbg !311 + %9410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 32, !dbg !311 + %9411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 33, !dbg !311 + %9412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 34, !dbg !311 + %9413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 35, !dbg !311 + %9414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 36, !dbg !311 + %9415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 37, !dbg !311 + %9416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 38, !dbg !311 + %9417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 39, !dbg !311 + %9418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 40, !dbg !311 + %9419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 41, !dbg !311 + %9420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 42, !dbg !311 + %9421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 43, !dbg !311 + %9422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 44, !dbg !311 + %9423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 45, !dbg !311 + %9424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 46, !dbg !311 + %9425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 47, !dbg !311 + %9426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 48, !dbg !311 + %9427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 49, !dbg !311 + %9428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 50, !dbg !311 + %9429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 51, !dbg !311 + %9430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 52, !dbg !311 + %9431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 53, !dbg !311 + %9432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 54, !dbg !311 + %9433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 55, !dbg !311 + %9434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 56, !dbg !311 + %9435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 57, !dbg !311 + %9436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 58, !dbg !311 + %9437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 59, !dbg !311 + %9438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 60, !dbg !311 + %9439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 61, !dbg !311 + %9440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9372, 62, !dbg !311 + %9441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float } %9372, 63, !dbg !311 + %9442 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9378, float %9379, float %9380, float %9381, float %9382, float %9383, float %9384, float %9385, float %9386, float %9387, float %9388, float %9389, float %9390, float %9391, float %9392, float %9393, float %9394, float %9395, float %9396, float %9397, float %9398, float %9399, float %9400, float %9401, float %9402, float %9403, float %9404, float %9405, float %9406, float %9407, float %9408, float %9409, float %9410, float %9411, float %9412, float %9413, float %9414, float %9415, float %9416, float %9417, float %9418, float %9419, float %9420, float %9421, float %9422, float %9423, float %9424, float %9425, float %9426, float %9427, float %9428, float 
%9429, float %9430, float %9431, float %9432, float %9433, float %9434, float %9435, float %9436, float %9437, float %9438, float %9439, float %9440, float %9441, i32 %9294, i32 %9295, i32 %9296, i32 %9297, i64 %9377, i1 true) #3, !dbg !311 + %9443 = add i32 %7837, 6144, !dbg !311 + %9444 = lshr exact i32 %9443, 4, !dbg !311 + %9445 = and i32 %9444, 16383, !dbg !311 + %9446 = zext nneg i32 %9445 to i64, !dbg !311 + %9447 = or disjoint i64 %9446, 4611686293338849280, !dbg !311 + %9448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 0, !dbg !311 + %9449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 1, !dbg !311 + %9450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 2, !dbg !311 + %9451 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 3, !dbg !311 + %9452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 4, !dbg !311 + %9453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 5, !dbg !311 + %9454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 6, !dbg !311 + %9455 = extractvalue { float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 7, !dbg !311 + %9456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 8, !dbg !311 + %9457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 9, !dbg !311 + %9458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 10, !dbg !311 + %9459 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 11, !dbg !311 + %9460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 12, !dbg !311 + %9461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 13, !dbg !311 + %9462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 14, !dbg !311 + %9463 = extractvalue { float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 15, !dbg !311 + %9464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 16, !dbg !311 + %9465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 17, !dbg !311 + %9466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 18, !dbg !311 + %9467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 19, !dbg !311 + %9468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 20, !dbg !311 + %9469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 21, !dbg !311 + %9470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 22, !dbg !311 + %9471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 23, !dbg !311 + %9472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 24, !dbg !311 + %9473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 25, !dbg !311 + %9474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 26, !dbg !311 + %9475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 27, !dbg !311 + %9476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 28, !dbg !311 + %9477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 29, !dbg !311 + %9478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 30, !dbg !311 + %9479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 31, !dbg !311 + %9480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 32, !dbg !311 + %9481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 33, !dbg !311 + %9482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 34, !dbg !311 + %9483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 35, !dbg !311 + %9484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 36, !dbg !311 + %9485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 37, !dbg !311 + %9486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 38, !dbg !311 + %9487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 39, !dbg !311 + %9488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 40, !dbg !311 + %9489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 41, !dbg !311 + %9490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 42, !dbg !311 + %9491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 43, !dbg !311 + %9492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 44, !dbg !311 + %9493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 45, !dbg !311 + %9494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 46, !dbg !311 + %9495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 47, !dbg !311 + %9496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 48, !dbg !311 + %9497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 49, !dbg !311 + %9498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 50, !dbg !311 + %9499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 51, !dbg !311 + %9500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 52, !dbg !311 + %9501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 53, !dbg !311 + %9502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 54, !dbg !311 + %9503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 55, !dbg !311 + %9504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 56, !dbg !311 + %9505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 57, !dbg !311 + %9506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 58, !dbg !311 + %9507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 59, !dbg !311 + %9508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 60, !dbg !311 + %9509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 61, !dbg !311 + %9510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 62, !dbg !311 + %9511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9442, 63, !dbg !311 + %9512 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9448, float %9449, float %9450, float %9451, float %9452, float %9453, float %9454, float %9455, float %9456, float %9457, float %9458, float %9459, float %9460, float %9461, float %9462, float %9463, float %9464, float %9465, float %9466, float %9467, float %9468, float %9469, float %9470, float %9471, float %9472, float %9473, float %9474, float %9475, float %9476, float %9477, float %9478, float %9479, float %9480, float %9481, float %9482, float %9483, float %9484, float %9485, float %9486, float %9487, float %9488, 
float %9489, float %9490, float %9491, float %9492, float %9493, float %9494, float %9495, float %9496, float %9497, float %9498, float %9499, float %9500, float %9501, float %9502, float %9503, float %9504, float %9505, float %9506, float %9507, float %9508, float %9509, float %9510, float %9511, i32 %9298, i32 %9299, i32 %9300, i32 %9301, i64 %9447, i1 true) #3, !dbg !311 + %9513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 0, !dbg !311 + %9514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 1, !dbg !311 + %9515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 2, !dbg !311 + %9516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 3, !dbg !311 + %9517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 4, !dbg !311 + %9518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 5, !dbg !311 + %9519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 6, !dbg !311 + %9520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 7, !dbg !311 + %9521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 8, !dbg !311 + %9522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 9, !dbg !311 + %9523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 10, !dbg !311 + %9524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 11, !dbg !311 + %9525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 12, !dbg !311 + %9526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 13, !dbg !311 + %9527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 14, !dbg !311 + %9528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 15, !dbg !311 + %9529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 16, !dbg !311 + %9530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 17, !dbg !311 + %9531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 18, !dbg !311 + %9532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 19, !dbg !311 + %9533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 20, !dbg !311 + %9534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 21, !dbg !311 + %9535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 22, !dbg !311 + %9536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 23, !dbg !311 + %9537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 24, !dbg !311 + %9538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 25, !dbg !311 + %9539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 26, !dbg !311 + %9540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 27, !dbg !311 + %9541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 28, !dbg !311 + %9542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 29, !dbg !311 + %9543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 30, !dbg !311 + %9544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 31, !dbg !311 + %9545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 32, !dbg !311 + %9546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 33, !dbg !311 + %9547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 34, !dbg !311 + %9548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 35, !dbg !311 + %9549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 36, !dbg !311 + %9550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 37, !dbg !311 + %9551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 38, !dbg !311 + %9552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 39, !dbg !311 + %9553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 40, !dbg !311 + %9554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 41, !dbg !311 + %9555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 42, !dbg !311 + %9556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 43, !dbg !311 + %9557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 44, !dbg !311 + %9558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 45, !dbg !311 + %9559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 46, !dbg !311 + %9560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 47, !dbg !311 + %9561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 48, !dbg !311 + %9562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 49, !dbg !311 + %9563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 50, !dbg !311 + %9564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 51, !dbg !311 + %9565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 52, !dbg !311 + %9566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 53, !dbg !311 + %9567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 54, !dbg !311 + %9568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 55, !dbg !311 + %9569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 56, !dbg !311 + %9570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 57, !dbg !311 + %9571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 58, !dbg !311 + %9572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 59, !dbg !311 + %9573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 60, !dbg !311 + %9574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 61, !dbg !311 + %9575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 62, !dbg !311 + %9576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9512, 63, !dbg !311 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !311 + %9577 = add nuw nsw i32 %7752, 1, !dbg !297 + %9578 = lshr i32 %9577, 1, !dbg !312 + %9579 = zext nneg i32 %9578 to i64, !dbg !313 + %9580 = getelementptr i32, ptr addrspace(1) %4756, i64 %9579, !dbg !313 + %9581 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !314 + %9582 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %9580, i64 %9581, i1 %7754) #3, !dbg !314 + %9583 = add nuw nsw i32 %9578, 1, !dbg !315 + %9584 = icmp slt i32 %9583, %4760, !dbg !316 + %9585 = getelementptr i8, ptr addrspace(1) %9580, i64 4, !dbg !317 + %9586 = and i1 %7754, %9584, !dbg !297 + %9587 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !318 + %9588 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %9585, i64 %9587, i1 %9586) #3, !dbg !318 + %9589 = and i32 %7752, 1, !dbg !319 + %9590 = sub i32 %9588, %9582, !dbg !320 + %9591 = shl i32 %9590, 7, !dbg !321 + %9592 = add i32 %9591, -64, !dbg !322 + %9593 = xor i32 %9589, 1, !dbg !323 + %9594 = mul nuw nsw i32 %9592, %9593, !dbg !323 + %9595 = shl nuw nsw i32 %9589, 6, !dbg !324 + %9596 = add i32 %9594, %9595, !dbg !325 + %9597 = shl i32 %9596, 12, !dbg !326 + %9598 = sext i32 %9597 to i64, !dbg !295 + %9599 = getelementptr bfloat, ptr addrspace(1) %.pn5681674, i64 %9598, !dbg !295 + %9600 = getelementptr bfloat, ptr addrspace(1) %.pn5521675, i64 %9598, !dbg !295 + %9601 = getelementptr bfloat, ptr addrspace(1) %.pn5361676, i64 %9598, !dbg !295 + %9602 = getelementptr bfloat, 
ptr addrspace(1) %.pn5201677, i64 %9598, !dbg !295 + %9603 = shl i32 %9596, 7, !dbg !327 + %9604 = sext i32 %9603 to i64, !dbg !296 + %9605 = getelementptr bfloat, ptr addrspace(1) %.pn6321678, i64 %9604, !dbg !296 + %9606 = getelementptr bfloat, ptr addrspace(1) %.pn6161679, i64 %9604, !dbg !296 + %9607 = getelementptr bfloat, ptr addrspace(1) %.pn6001680, i64 %9604, !dbg !296 + %9608 = getelementptr bfloat, ptr addrspace(1) %.pn5841681, i64 %9604, !dbg !296 + %9609 = add i32 %9596, %.pn6641682, !dbg !328 + %9610 = add i32 %9596, %.pn6601683, !dbg !328 + %9611 = add i32 %9596, %.pn6561684, !dbg !328 + %9612 = add i32 %9596, %.pn6521685, !dbg !328 + %9613 = add i32 %9596, %.pn6481686, !dbg !328 + %9614 = add i32 %9596, %.pn6441687, !dbg !328 + %9615 = add i32 %9596, %.pn6401688, !dbg !328 + %9616 = add i32 %9596, %.pn6361689, !dbg !328 + %9617 = add i32 %7749, 1, !dbg !297 + %9618 = icmp sgt i32 %9617, 1, !dbg !297 + %9619 = select i1 %9618, i32 0, i32 %9617, !dbg !297 + %9620 = add i32 %7751, 1, !dbg !297 + %9621 = icmp sgt i32 %9620, 2, !dbg !297 + %9622 = select i1 %9621, i32 0, i32 %9620, !dbg !297 + %9623 = shl i32 %9622, 13, !dbg !290 + %9624 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %9623, !dbg !290 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !290 + %9625 = getelementptr inbounds nuw i8, ptr addrspace(3) %9624, i32 %4793, !dbg !290 + %9626 = select i1 %7753, i32 16, i32 0, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %9625, ptr addrspace(1) %9599, i32 %9626) #3, !dbg !290 + %9627 = getelementptr inbounds nuw i8, ptr addrspace(3) %9624, i32 %4796, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9627, ptr addrspace(1) %9600, i32 %9626) #3, !dbg !290 + %9628 = getelementptr inbounds nuw i8, ptr addrspace(3) %9624, i32 %4798, !dbg !290 + tail call 
void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9628, ptr addrspace(1) %9601, i32 %9626) #3, !dbg !290 + %9629 = getelementptr inbounds nuw i8, ptr addrspace(3) %9624, i32 %4800, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9629, ptr addrspace(1) %9602, i32 %9626) #3, !dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %9630 = sext i32 %9609 to i64, !dbg !291 + %9631 = getelementptr float, ptr addrspace(1) %5087, i64 %9630, !dbg !291 + %9632 = sext i32 %9610 to i64, !dbg !291 + %9633 = getelementptr float, ptr addrspace(1) %5087, i64 %9632, !dbg !291 + %9634 = sext i32 %9611 to i64, !dbg !291 + %9635 = getelementptr float, ptr addrspace(1) %5087, i64 %9634, !dbg !291 + %9636 = sext i32 %9612 to i64, !dbg !291 + %9637 = getelementptr float, ptr addrspace(1) %5087, i64 %9636, !dbg !291 + %9638 = sext i32 %9613 to i64, !dbg !291 + %9639 = getelementptr float, ptr addrspace(1) %5087, i64 %9638, !dbg !291 + %9640 = sext i32 %9614 to i64, !dbg !291 + %9641 = getelementptr float, ptr addrspace(1) %5087, i64 %9640, !dbg !291 + %9642 = sext i32 %9615 to i64, !dbg !291 + %9643 = getelementptr float, ptr addrspace(1) %5087, i64 %9642, !dbg !291 + %9644 = sext i32 %9616 to i64, !dbg !291 + %9645 = getelementptr float, ptr addrspace(1) %5087, i64 %9644, !dbg !291 + %9646 = shl i32 %9619, 6, !dbg !292 + %9647 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %9646, !dbg !292 + %9648 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4812, !dbg !292 + %9649 = select i1 %7753, i32 8, i32 0, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %9648, ptr addrspace(1) %9631, i32 %9649, i1 %4811) #3, !dbg !292 + %9650 = getelementptr inbounds nuw i8, ptr 
addrspace(3) %9647, i32 %4815, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9650, ptr addrspace(1) %9633, i32 %9649, i1 %4811) #3, !dbg !292 + %9651 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4817, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9651, ptr addrspace(1) %9635, i32 %9649, i1 %4811) #3, !dbg !292 + %9652 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4819, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9652, ptr addrspace(1) %9637, i32 %9649, i1 %4811) #3, !dbg !292 + %9653 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4821, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9653, ptr addrspace(1) %9639, i32 %9649, i1 %4811) #3, !dbg !292 + %9654 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4823, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9654, ptr addrspace(1) %9641, i32 %9649, i1 %4811) #3, !dbg !292 + %9655 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4825, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9655, ptr addrspace(1) %9643, i32 %9649, i1 %4811) #3, !dbg !292 + %9656 = getelementptr inbounds nuw i8, ptr addrspace(3) %9647, i32 %4827, !dbg !292 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9656, ptr addrspace(1) %9645, i32 %9649, i1 %4811) #3, !dbg !292 + tail call void 
@llvm.nvvm.cp.async.commit.group(), !dbg !292 + %9657 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %9623, !dbg !290 + %9658 = getelementptr inbounds nuw i8, ptr addrspace(3) %9657, i32 %4793, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %9658, ptr addrspace(1) %9605, i32 %9626) #3, !dbg !290 + %9659 = getelementptr inbounds nuw i8, ptr addrspace(3) %9657, i32 %4796, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9659, ptr addrspace(1) %9606, i32 %9626) #3, !dbg !290 + %9660 = getelementptr inbounds nuw i8, ptr addrspace(3) %9657, i32 %4798, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9660, ptr addrspace(1) %9607, i32 %9626) #3, !dbg !290 + %9661 = getelementptr inbounds nuw i8, ptr addrspace(3) %9657, i32 %4800, !dbg !290 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %9661, ptr addrspace(1) %9608, i32 %9626) #3, !dbg !290 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !290 + %9662 = getelementptr float, ptr addrspace(1) %5088, i64 %9630, !dbg !293 + %9663 = getelementptr float, ptr addrspace(1) %5088, i64 %9632, !dbg !293 + %9664 = getelementptr float, ptr addrspace(1) %5088, i64 %9634, !dbg !293 + %9665 = getelementptr float, ptr addrspace(1) %5088, i64 %9636, !dbg !293 + %9666 = getelementptr float, ptr addrspace(1) %5088, i64 %9638, !dbg !293 + %9667 = getelementptr float, ptr addrspace(1) %5088, i64 %9640, !dbg !293 + %9668 = getelementptr float, ptr addrspace(1) %5088, i64 %9642, !dbg !293 + %9669 = getelementptr float, ptr addrspace(1) %5088, i64 %9644, !dbg !293 + %9670 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr 
addrspace(3) @global_smem, i32 98816), i32 %9646, !dbg !294 + %9671 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4812, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) %9671, ptr addrspace(1) %9662, i32 %9649, i1 %4811) #3, !dbg !294 + %9672 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4815, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9672, ptr addrspace(1) %9663, i32 %9649, i1 %4811) #3, !dbg !294 + %9673 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4817, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9673, ptr addrspace(1) %9664, i32 %9649, i1 %4811) #3, !dbg !294 + %9674 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4819, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9674, ptr addrspace(1) %9665, i32 %9649, i1 %4811) #3, !dbg !294 + %9675 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4821, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9675, ptr addrspace(1) %9666, i32 %9649, i1 %4811) #3, !dbg !294 + %9676 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4823, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9676, ptr addrspace(1) %9667, i32 %9649, i1 %4811) #3, !dbg !294 + %9677 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4825, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9677, ptr 
addrspace(1) %9668, i32 %9649, i1 %4811) #3, !dbg !294 + %9678 = getelementptr inbounds nuw i8, ptr addrspace(3) %9670, i32 %4827, !dbg !294 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x8, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %9678, ptr addrspace(1) %9669, i32 %9649, i1 %4811) #3, !dbg !294 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !294 + %exitcond2122.not = icmp eq i32 %9577, %smax2121, !dbg !297 + br i1 %exitcond2122.not, label %._crit_edge1692, label %.lr.ph1691, !dbg !297 + +._crit_edge1692: ; preds = %__nv_exp2f.exit1219, %._crit_edge + %.pn376.lcssa = phi float [ %7628, %._crit_edge ], [ %9513, %__nv_exp2f.exit1219 ] + %.pn374.lcssa = phi float [ %7629, %._crit_edge ], [ %9514, %__nv_exp2f.exit1219 ] + %.pn372.lcssa = phi float [ %7630, %._crit_edge ], [ %9515, %__nv_exp2f.exit1219 ] + %.pn370.lcssa = phi float [ %7631, %._crit_edge ], [ %9516, %__nv_exp2f.exit1219 ] + %.pn368.lcssa = phi float [ %7632, %._crit_edge ], [ %9517, %__nv_exp2f.exit1219 ] + %.pn366.lcssa = phi float [ %7633, %._crit_edge ], [ %9518, %__nv_exp2f.exit1219 ] + %.pn364.lcssa = phi float [ %7634, %._crit_edge ], [ %9519, %__nv_exp2f.exit1219 ] + %.pn362.lcssa = phi float [ %7635, %._crit_edge ], [ %9520, %__nv_exp2f.exit1219 ] + %.pn360.lcssa = phi float [ %7636, %._crit_edge ], [ %9521, %__nv_exp2f.exit1219 ] + %.pn358.lcssa = phi float [ %7637, %._crit_edge ], [ %9522, %__nv_exp2f.exit1219 ] + %.pn356.lcssa = phi float [ %7638, %._crit_edge ], [ %9523, %__nv_exp2f.exit1219 ] + %.pn354.lcssa = phi float [ %7639, %._crit_edge ], [ %9524, %__nv_exp2f.exit1219 ] + %.pn352.lcssa = phi float [ %7640, %._crit_edge ], [ %9525, %__nv_exp2f.exit1219 ] + %.pn350.lcssa = phi float [ %7641, %._crit_edge ], [ %9526, %__nv_exp2f.exit1219 ] + %.pn348.lcssa = phi float [ %7642, %._crit_edge ], [ %9527, %__nv_exp2f.exit1219 ] + %.pn346.lcssa = phi float [ %7643, %._crit_edge ], [ %9528, %__nv_exp2f.exit1219 ] + %.pn344.lcssa = phi float [ 
%7644, %._crit_edge ], [ %9529, %__nv_exp2f.exit1219 ] + %.pn342.lcssa = phi float [ %7645, %._crit_edge ], [ %9530, %__nv_exp2f.exit1219 ] + %.pn340.lcssa = phi float [ %7646, %._crit_edge ], [ %9531, %__nv_exp2f.exit1219 ] + %.pn338.lcssa = phi float [ %7647, %._crit_edge ], [ %9532, %__nv_exp2f.exit1219 ] + %.pn336.lcssa = phi float [ %7648, %._crit_edge ], [ %9533, %__nv_exp2f.exit1219 ] + %.pn334.lcssa = phi float [ %7649, %._crit_edge ], [ %9534, %__nv_exp2f.exit1219 ] + %.pn332.lcssa = phi float [ %7650, %._crit_edge ], [ %9535, %__nv_exp2f.exit1219 ] + %.pn330.lcssa = phi float [ %7651, %._crit_edge ], [ %9536, %__nv_exp2f.exit1219 ] + %.pn328.lcssa = phi float [ %7652, %._crit_edge ], [ %9537, %__nv_exp2f.exit1219 ] + %.pn326.lcssa = phi float [ %7653, %._crit_edge ], [ %9538, %__nv_exp2f.exit1219 ] + %.pn324.lcssa = phi float [ %7654, %._crit_edge ], [ %9539, %__nv_exp2f.exit1219 ] + %.pn322.lcssa = phi float [ %7655, %._crit_edge ], [ %9540, %__nv_exp2f.exit1219 ] + %.pn320.lcssa = phi float [ %7656, %._crit_edge ], [ %9541, %__nv_exp2f.exit1219 ] + %.pn318.lcssa = phi float [ %7657, %._crit_edge ], [ %9542, %__nv_exp2f.exit1219 ] + %.pn316.lcssa = phi float [ %7658, %._crit_edge ], [ %9543, %__nv_exp2f.exit1219 ] + %.pn314.lcssa = phi float [ %7659, %._crit_edge ], [ %9544, %__nv_exp2f.exit1219 ] + %.pn312.lcssa = phi float [ %7660, %._crit_edge ], [ %9545, %__nv_exp2f.exit1219 ] + %.pn310.lcssa = phi float [ %7661, %._crit_edge ], [ %9546, %__nv_exp2f.exit1219 ] + %.pn308.lcssa = phi float [ %7662, %._crit_edge ], [ %9547, %__nv_exp2f.exit1219 ] + %.pn306.lcssa = phi float [ %7663, %._crit_edge ], [ %9548, %__nv_exp2f.exit1219 ] + %.pn304.lcssa = phi float [ %7664, %._crit_edge ], [ %9549, %__nv_exp2f.exit1219 ] + %.pn302.lcssa = phi float [ %7665, %._crit_edge ], [ %9550, %__nv_exp2f.exit1219 ] + %.pn300.lcssa = phi float [ %7666, %._crit_edge ], [ %9551, %__nv_exp2f.exit1219 ] + %.pn298.lcssa = phi float [ %7667, %._crit_edge ], [ %9552, 
%__nv_exp2f.exit1219 ] + %.pn296.lcssa = phi float [ %7668, %._crit_edge ], [ %9553, %__nv_exp2f.exit1219 ] + %.pn294.lcssa = phi float [ %7669, %._crit_edge ], [ %9554, %__nv_exp2f.exit1219 ] + %.pn292.lcssa = phi float [ %7670, %._crit_edge ], [ %9555, %__nv_exp2f.exit1219 ] + %.pn290.lcssa = phi float [ %7671, %._crit_edge ], [ %9556, %__nv_exp2f.exit1219 ] + %.pn288.lcssa = phi float [ %7672, %._crit_edge ], [ %9557, %__nv_exp2f.exit1219 ] + %.pn286.lcssa = phi float [ %7673, %._crit_edge ], [ %9558, %__nv_exp2f.exit1219 ] + %.pn284.lcssa = phi float [ %7674, %._crit_edge ], [ %9559, %__nv_exp2f.exit1219 ] + %.pn282.lcssa = phi float [ %7675, %._crit_edge ], [ %9560, %__nv_exp2f.exit1219 ] + %.pn280.lcssa = phi float [ %7676, %._crit_edge ], [ %9561, %__nv_exp2f.exit1219 ] + %.pn278.lcssa = phi float [ %7677, %._crit_edge ], [ %9562, %__nv_exp2f.exit1219 ] + %.pn276.lcssa = phi float [ %7678, %._crit_edge ], [ %9563, %__nv_exp2f.exit1219 ] + %.pn274.lcssa = phi float [ %7679, %._crit_edge ], [ %9564, %__nv_exp2f.exit1219 ] + %.pn272.lcssa = phi float [ %7680, %._crit_edge ], [ %9565, %__nv_exp2f.exit1219 ] + %.pn270.lcssa = phi float [ %7681, %._crit_edge ], [ %9566, %__nv_exp2f.exit1219 ] + %.pn268.lcssa = phi float [ %7682, %._crit_edge ], [ %9567, %__nv_exp2f.exit1219 ] + %.pn266.lcssa = phi float [ %7683, %._crit_edge ], [ %9568, %__nv_exp2f.exit1219 ] + %.pn264.lcssa = phi float [ %7684, %._crit_edge ], [ %9569, %__nv_exp2f.exit1219 ] + %.pn262.lcssa = phi float [ %7685, %._crit_edge ], [ %9570, %__nv_exp2f.exit1219 ] + %.pn260.lcssa = phi float [ %7686, %._crit_edge ], [ %9571, %__nv_exp2f.exit1219 ] + %.pn258.lcssa = phi float [ %7687, %._crit_edge ], [ %9572, %__nv_exp2f.exit1219 ] + %.pn256.lcssa = phi float [ %7688, %._crit_edge ], [ %9573, %__nv_exp2f.exit1219 ] + %.pn254.lcssa = phi float [ %7689, %._crit_edge ], [ %9574, %__nv_exp2f.exit1219 ] + %.pn252.lcssa = phi float [ %7690, %._crit_edge ], [ %9575, %__nv_exp2f.exit1219 ] + %.pn250.lcssa = phi 
float [ %7691, %._crit_edge ], [ %9576, %__nv_exp2f.exit1219 ] + %.pn504.lcssa = phi float [ %7564, %._crit_edge ], [ %8753, %__nv_exp2f.exit1219 ] + %.pn502.lcssa = phi float [ %7565, %._crit_edge ], [ %8754, %__nv_exp2f.exit1219 ] + %.pn500.lcssa = phi float [ %7566, %._crit_edge ], [ %8755, %__nv_exp2f.exit1219 ] + %.pn498.lcssa = phi float [ %7567, %._crit_edge ], [ %8756, %__nv_exp2f.exit1219 ] + %.pn496.lcssa = phi float [ %7568, %._crit_edge ], [ %8757, %__nv_exp2f.exit1219 ] + %.pn494.lcssa = phi float [ %7569, %._crit_edge ], [ %8758, %__nv_exp2f.exit1219 ] + %.pn492.lcssa = phi float [ %7570, %._crit_edge ], [ %8759, %__nv_exp2f.exit1219 ] + %.pn490.lcssa = phi float [ %7571, %._crit_edge ], [ %8760, %__nv_exp2f.exit1219 ] + %.pn488.lcssa = phi float [ %7572, %._crit_edge ], [ %8761, %__nv_exp2f.exit1219 ] + %.pn486.lcssa = phi float [ %7573, %._crit_edge ], [ %8762, %__nv_exp2f.exit1219 ] + %.pn484.lcssa = phi float [ %7574, %._crit_edge ], [ %8763, %__nv_exp2f.exit1219 ] + %.pn482.lcssa = phi float [ %7575, %._crit_edge ], [ %8764, %__nv_exp2f.exit1219 ] + %.pn480.lcssa = phi float [ %7576, %._crit_edge ], [ %8765, %__nv_exp2f.exit1219 ] + %.pn478.lcssa = phi float [ %7577, %._crit_edge ], [ %8766, %__nv_exp2f.exit1219 ] + %.pn476.lcssa = phi float [ %7578, %._crit_edge ], [ %8767, %__nv_exp2f.exit1219 ] + %.pn474.lcssa = phi float [ %7579, %._crit_edge ], [ %8768, %__nv_exp2f.exit1219 ] + %.pn472.lcssa = phi float [ %7580, %._crit_edge ], [ %8769, %__nv_exp2f.exit1219 ] + %.pn470.lcssa = phi float [ %7581, %._crit_edge ], [ %8770, %__nv_exp2f.exit1219 ] + %.pn468.lcssa = phi float [ %7582, %._crit_edge ], [ %8771, %__nv_exp2f.exit1219 ] + %.pn466.lcssa = phi float [ %7583, %._crit_edge ], [ %8772, %__nv_exp2f.exit1219 ] + %.pn464.lcssa = phi float [ %7584, %._crit_edge ], [ %8773, %__nv_exp2f.exit1219 ] + %.pn462.lcssa = phi float [ %7585, %._crit_edge ], [ %8774, %__nv_exp2f.exit1219 ] + %.pn460.lcssa = phi float [ %7586, %._crit_edge ], [ %8775, 
%__nv_exp2f.exit1219 ] + %.pn458.lcssa = phi float [ %7587, %._crit_edge ], [ %8776, %__nv_exp2f.exit1219 ] + %.pn456.lcssa = phi float [ %7588, %._crit_edge ], [ %8777, %__nv_exp2f.exit1219 ] + %.pn454.lcssa = phi float [ %7589, %._crit_edge ], [ %8778, %__nv_exp2f.exit1219 ] + %.pn452.lcssa = phi float [ %7590, %._crit_edge ], [ %8779, %__nv_exp2f.exit1219 ] + %.pn450.lcssa = phi float [ %7591, %._crit_edge ], [ %8780, %__nv_exp2f.exit1219 ] + %.pn448.lcssa = phi float [ %7592, %._crit_edge ], [ %8781, %__nv_exp2f.exit1219 ] + %.pn446.lcssa = phi float [ %7593, %._crit_edge ], [ %8782, %__nv_exp2f.exit1219 ] + %.pn444.lcssa = phi float [ %7594, %._crit_edge ], [ %8783, %__nv_exp2f.exit1219 ] + %.pn442.lcssa = phi float [ %7595, %._crit_edge ], [ %8784, %__nv_exp2f.exit1219 ] + %.pn440.lcssa = phi float [ %7596, %._crit_edge ], [ %8785, %__nv_exp2f.exit1219 ] + %.pn438.lcssa = phi float [ %7597, %._crit_edge ], [ %8786, %__nv_exp2f.exit1219 ] + %.pn436.lcssa = phi float [ %7598, %._crit_edge ], [ %8787, %__nv_exp2f.exit1219 ] + %.pn434.lcssa = phi float [ %7599, %._crit_edge ], [ %8788, %__nv_exp2f.exit1219 ] + %.pn432.lcssa = phi float [ %7600, %._crit_edge ], [ %8789, %__nv_exp2f.exit1219 ] + %.pn430.lcssa = phi float [ %7601, %._crit_edge ], [ %8790, %__nv_exp2f.exit1219 ] + %.pn428.lcssa = phi float [ %7602, %._crit_edge ], [ %8791, %__nv_exp2f.exit1219 ] + %.pn426.lcssa = phi float [ %7603, %._crit_edge ], [ %8792, %__nv_exp2f.exit1219 ] + %.pn424.lcssa = phi float [ %7604, %._crit_edge ], [ %8793, %__nv_exp2f.exit1219 ] + %.pn422.lcssa = phi float [ %7605, %._crit_edge ], [ %8794, %__nv_exp2f.exit1219 ] + %.pn420.lcssa = phi float [ %7606, %._crit_edge ], [ %8795, %__nv_exp2f.exit1219 ] + %.pn418.lcssa = phi float [ %7607, %._crit_edge ], [ %8796, %__nv_exp2f.exit1219 ] + %.pn416.lcssa = phi float [ %7608, %._crit_edge ], [ %8797, %__nv_exp2f.exit1219 ] + %.pn414.lcssa = phi float [ %7609, %._crit_edge ], [ %8798, %__nv_exp2f.exit1219 ] + %.pn412.lcssa = phi 
float [ %7610, %._crit_edge ], [ %8799, %__nv_exp2f.exit1219 ] + %.pn410.lcssa = phi float [ %7611, %._crit_edge ], [ %8800, %__nv_exp2f.exit1219 ] + %.pn408.lcssa = phi float [ %7612, %._crit_edge ], [ %8801, %__nv_exp2f.exit1219 ] + %.pn406.lcssa = phi float [ %7613, %._crit_edge ], [ %8802, %__nv_exp2f.exit1219 ] + %.pn404.lcssa = phi float [ %7614, %._crit_edge ], [ %8803, %__nv_exp2f.exit1219 ] + %.pn402.lcssa = phi float [ %7615, %._crit_edge ], [ %8804, %__nv_exp2f.exit1219 ] + %.pn400.lcssa = phi float [ %7616, %._crit_edge ], [ %8805, %__nv_exp2f.exit1219 ] + %.pn398.lcssa = phi float [ %7617, %._crit_edge ], [ %8806, %__nv_exp2f.exit1219 ] + %.pn396.lcssa = phi float [ %7618, %._crit_edge ], [ %8807, %__nv_exp2f.exit1219 ] + %.pn394.lcssa = phi float [ %7619, %._crit_edge ], [ %8808, %__nv_exp2f.exit1219 ] + %.pn392.lcssa = phi float [ %7620, %._crit_edge ], [ %8809, %__nv_exp2f.exit1219 ] + %.pn390.lcssa = phi float [ %7621, %._crit_edge ], [ %8810, %__nv_exp2f.exit1219 ] + %.pn388.lcssa = phi float [ %7622, %._crit_edge ], [ %8811, %__nv_exp2f.exit1219 ] + %.pn386.lcssa = phi float [ %7623, %._crit_edge ], [ %8812, %__nv_exp2f.exit1219 ] + %.pn384.lcssa = phi float [ %7624, %._crit_edge ], [ %8813, %__nv_exp2f.exit1219 ] + %.pn382.lcssa = phi float [ %7625, %._crit_edge ], [ %8814, %__nv_exp2f.exit1219 ] + %.pn380.lcssa = phi float [ %7626, %._crit_edge ], [ %8815, %__nv_exp2f.exit1219 ] + %.pn378.lcssa = phi float [ %7627, %._crit_edge ], [ %8816, %__nv_exp2f.exit1219 ] + %9679 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %.pn504.lcssa, float %.pn502.lcssa, float %.pn500.lcssa, float %.pn498.lcssa, float %.pn496.lcssa, float %.pn494.lcssa, float %.pn492.lcssa, float %.pn490.lcssa, float 
%.pn488.lcssa, float %.pn486.lcssa, float %.pn484.lcssa, float %.pn482.lcssa, float %.pn480.lcssa, float %.pn478.lcssa, float %.pn476.lcssa, float %.pn474.lcssa, float %.pn472.lcssa, float %.pn470.lcssa, float %.pn468.lcssa, float %.pn466.lcssa, float %.pn464.lcssa, float %.pn462.lcssa, float %.pn460.lcssa, float %.pn458.lcssa, float %.pn456.lcssa, float %.pn454.lcssa, float %.pn452.lcssa, float %.pn450.lcssa, float %.pn448.lcssa, float %.pn446.lcssa, float %.pn444.lcssa, float %.pn442.lcssa, float %.pn440.lcssa, float %.pn438.lcssa, float %.pn436.lcssa, float %.pn434.lcssa, float %.pn432.lcssa, float %.pn430.lcssa, float %.pn428.lcssa, float %.pn426.lcssa, float %.pn424.lcssa, float %.pn422.lcssa, float %.pn420.lcssa, float %.pn418.lcssa, float %.pn416.lcssa, float %.pn414.lcssa, float %.pn412.lcssa, float %.pn410.lcssa, float %.pn408.lcssa, float %.pn406.lcssa, float %.pn404.lcssa, float %.pn402.lcssa, float %.pn400.lcssa, float %.pn398.lcssa, float %.pn396.lcssa, float %.pn394.lcssa, float %.pn392.lcssa, float %.pn390.lcssa, float %.pn388.lcssa, float %.pn386.lcssa, float %.pn384.lcssa, float %.pn382.lcssa, float %.pn380.lcssa, float %.pn378.lcssa, float %.pn376.lcssa, float %.pn374.lcssa, float %.pn372.lcssa, float %.pn370.lcssa, float %.pn368.lcssa, float %.pn366.lcssa, float %.pn364.lcssa, float %.pn362.lcssa, float %.pn360.lcssa, float %.pn358.lcssa, float %.pn356.lcssa, float %.pn354.lcssa, float %.pn352.lcssa, float %.pn350.lcssa, float %.pn348.lcssa, float %.pn346.lcssa, float %.pn344.lcssa, float %.pn342.lcssa, float %.pn340.lcssa, float %.pn338.lcssa, float %.pn336.lcssa, float %.pn334.lcssa, float %.pn332.lcssa, float %.pn330.lcssa, float %.pn328.lcssa, float %.pn326.lcssa, float %.pn324.lcssa, float %.pn322.lcssa, float %.pn320.lcssa, float %.pn318.lcssa, float %.pn316.lcssa, float %.pn314.lcssa, float %.pn312.lcssa, float %.pn310.lcssa, float %.pn308.lcssa, float %.pn306.lcssa, float %.pn304.lcssa, float %.pn302.lcssa, float %.pn300.lcssa, float 
%.pn298.lcssa, float %.pn296.lcssa, float %.pn294.lcssa, float %.pn292.lcssa, float %.pn290.lcssa, float %.pn288.lcssa, float %.pn286.lcssa, float %.pn284.lcssa, float %.pn282.lcssa, float %.pn280.lcssa, float %.pn278.lcssa, float %.pn276.lcssa, float %.pn274.lcssa, float %.pn272.lcssa, float %.pn270.lcssa, float %.pn268.lcssa, float %.pn266.lcssa, float %.pn264.lcssa, float %.pn262.lcssa, float %.pn260.lcssa, float %.pn258.lcssa, float %.pn256.lcssa, float %.pn254.lcssa, float %.pn252.lcssa, float %.pn250.lcssa) #3, !dbg !297 + %9680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 0, !dbg !297 + %9681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 1, !dbg !297 + %9682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 2, !dbg !297 + %9683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 3, !dbg !297 + %9684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 4, !dbg !297 + %9685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 5, !dbg !297 + %9686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 6, !dbg !297 + %9687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float } %9679, 7, !dbg !297 + %9688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 8, !dbg !297 + %9689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 9, !dbg !297 + %9690 = extractvalue { float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 10, !dbg !297 + %9691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 11, !dbg !297 + %9692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 12, !dbg !297 + %9693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 13, !dbg !297 + %9694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 14, !dbg !297 + %9695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 15, !dbg !297 + %9696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 16, !dbg !297 + %9697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 17, !dbg !297 + %9698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 18, !dbg !297 + %9699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 19, !dbg !297 + %9700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 20, !dbg !297 + %9701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 21, !dbg !297 + %9702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 22, !dbg !297 + %9703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 23, !dbg !297 + %9704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 24, !dbg !297 + %9705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 25, !dbg !297 + %9706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 
26, !dbg !297 + %9707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 27, !dbg !297 + %9708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 28, !dbg !297 + %9709 = extractvalue { float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 29, !dbg !297 + %9710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 30, !dbg !297 + %9711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 31, !dbg !297 + %9712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 32, !dbg !297 + %9713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 33, !dbg !297 + %9714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 34, !dbg !297 + %9715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 35, !dbg !297 + %9716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 36, !dbg !297 + %9717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 37, !dbg !297 + %9718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 38, !dbg !297 + %9719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 39, !dbg !297 + %9720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 40, !dbg !297 + %9721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 41, !dbg !297 + %9722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 42, !dbg !297 + %9723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float } %9679, 43, !dbg !297 + %9724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 44, !dbg !297 + %9725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 45, !dbg !297 + %9726 = extractvalue { float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 46, !dbg !297 + %9727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 47, !dbg !297 + %9728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 48, !dbg !297 + %9729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 49, !dbg !297 + %9730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 50, !dbg !297 + %9731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 51, !dbg !297 + %9732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 52, !dbg !297 + %9733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 53, !dbg !297 + %9734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 54, !dbg !297 + %9735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 55, !dbg !297 + %9736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 56, !dbg !297 + %9737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 57, !dbg !297 + %9738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 58, !dbg !297 + %9739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 59, !dbg !297 + %9740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 60, !dbg !297 + %9741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 61, !dbg !297 + %9742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 62, !dbg !297 
+ %9743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 63, !dbg !297 + %9744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 64, !dbg !297 + %9745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 65, !dbg !297 + %9746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 66, !dbg !297 + %9747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 67, !dbg !297 + %9748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 68, !dbg !297 + %9749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 69, !dbg !297 + %9750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 70, !dbg !297 + %9751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 71, !dbg !297 + %9752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 72, !dbg !297 + %9753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 73, !dbg !297 + %9754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 74, !dbg !297 + %9755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 75, !dbg !297 + %9756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 76, !dbg !297 + %9757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 77, !dbg !297 + %9758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 78, !dbg !297 + %9759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float } %9679, 79, !dbg !297 + %9760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 80, !dbg !297 + %9761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 81, !dbg !297 + %9762 = extractvalue { float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 82, !dbg !297 + %9763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 83, !dbg !297 + %9764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 84, !dbg !297 + %9765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 85, !dbg !297 + %9766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 86, !dbg !297 + %9767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 87, !dbg !297 + %9768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 88, !dbg !297 + %9769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 89, !dbg !297 + %9770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 90, !dbg !297 + %9771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 91, !dbg !297 + %9772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 92, !dbg !297 + %9773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 93, !dbg !297 + %9774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 94, !dbg !297 + %9775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 95, !dbg !297 + %9776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float } %9679, 96, !dbg !297 + %9777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 97, !dbg !297 + %9778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 98, !dbg !297 + %9779 = 
extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 99, !dbg !297 + %9780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 100, !dbg !297 + %9781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 101, !dbg !297 + %9782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 102, !dbg !297 + %9783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 103, !dbg !297 + %9784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 104, !dbg !297 + %9785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 105, !dbg !297 + %9786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 106, !dbg !297 + %9787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 107, !dbg !297 + %9788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 108, !dbg !297 + %9789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 109, !dbg !297 + %9790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 110, !dbg !297 + %9791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 111, !dbg !297 + %9792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 112, !dbg !297 + %9793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 113, !dbg !297 + %9794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 114, !dbg !297 + %9795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float } %9679, 115, !dbg !297 + %9796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 116, !dbg !297 + %9797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 117, !dbg !297 + %9798 = extractvalue { float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 118, !dbg !297 + %9799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 119, !dbg !297 + %9800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 120, !dbg !297 + %9801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 121, !dbg !297 + %9802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 122, !dbg !297 + %9803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 123, !dbg !297 + %9804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 124, !dbg !297 + %9805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 125, !dbg !297 + %9806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, 
float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 126, !dbg !297 + %9807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9679, 127, !dbg !297 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !297 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !297 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1, !dbg !227 + %exitcond2123.not = icmp eq i64 %indvars.iv.next, 4, !dbg !227 + br i1 %exitcond2123.not, label %9808, label %4945, !dbg !227 + +9808: ; preds = %._crit_edge1692 + %9809 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4480, !dbg !329 + %9810 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4482, !dbg !329 + %9811 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4484, !dbg !329 + %9812 = getelementptr bfloat, ptr 
addrspace(1) %34, i64 %4486, !dbg !329 + %9813 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4488, !dbg !329 + %9814 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4490, !dbg !329 + %9815 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4492, !dbg !329 + %9816 = getelementptr bfloat, ptr addrspace(1) %34, i64 %4494, !dbg !329 + %9817 = getelementptr bfloat, ptr addrspace(1) %9809, i64 %4498, !dbg !330 + %9818 = getelementptr bfloat, ptr addrspace(1) %9810, i64 %4498, !dbg !330 + %9819 = getelementptr bfloat, ptr addrspace(1) %9811, i64 %4498, !dbg !330 + %9820 = getelementptr bfloat, ptr addrspace(1) %9812, i64 %4498, !dbg !330 + %9821 = getelementptr bfloat, ptr addrspace(1) %9813, i64 %4498, !dbg !330 + %9822 = getelementptr bfloat, ptr addrspace(1) %9814, i64 %4498, !dbg !330 + %9823 = getelementptr bfloat, ptr addrspace(1) %9815, i64 %4498, !dbg !330 + %9824 = getelementptr bfloat, ptr addrspace(1) %9816, i64 %4498, !dbg !330 + %9825 = insertelement <2 x float> poison, float %9680, i64 0, !dbg !331 + %9826 = insertelement <2 x float> %9825, float %9681, i64 1, !dbg !331 + %9827 = fptrunc <2 x float> %9826 to <2 x bfloat>, !dbg !331 + %9828 = insertelement <2 x float> poison, float %9682, i64 0, !dbg !331 + %9829 = insertelement <2 x float> %9828, float %9683, i64 1, !dbg !331 + %9830 = fptrunc <2 x float> %9829 to <2 x bfloat>, !dbg !331 + %9831 = insertelement <2 x float> poison, float %9684, i64 0, !dbg !331 + %9832 = insertelement <2 x float> %9831, float %9685, i64 1, !dbg !331 + %9833 = fptrunc <2 x float> %9832 to <2 x bfloat>, !dbg !331 + %9834 = insertelement <2 x float> poison, float %9686, i64 0, !dbg !331 + %9835 = insertelement <2 x float> %9834, float %9687, i64 1, !dbg !331 + %9836 = fptrunc <2 x float> %9835 to <2 x bfloat>, !dbg !331 + %9837 = insertelement <2 x float> poison, float %9688, i64 0, !dbg !331 + %9838 = insertelement <2 x float> %9837, float %9689, i64 1, !dbg !331 + %9839 = fptrunc <2 x float> %9838 to <2 x bfloat>, !dbg 
!331 + %9840 = insertelement <2 x float> poison, float %9690, i64 0, !dbg !331 + %9841 = insertelement <2 x float> %9840, float %9691, i64 1, !dbg !331 + %9842 = fptrunc <2 x float> %9841 to <2 x bfloat>, !dbg !331 + %9843 = insertelement <2 x float> poison, float %9692, i64 0, !dbg !331 + %9844 = insertelement <2 x float> %9843, float %9693, i64 1, !dbg !331 + %9845 = fptrunc <2 x float> %9844 to <2 x bfloat>, !dbg !331 + %9846 = insertelement <2 x float> poison, float %9694, i64 0, !dbg !331 + %9847 = insertelement <2 x float> %9846, float %9695, i64 1, !dbg !331 + %9848 = fptrunc <2 x float> %9847 to <2 x bfloat>, !dbg !331 + %9849 = insertelement <2 x float> poison, float %9696, i64 0, !dbg !331 + %9850 = insertelement <2 x float> %9849, float %9697, i64 1, !dbg !331 + %9851 = fptrunc <2 x float> %9850 to <2 x bfloat>, !dbg !331 + %9852 = insertelement <2 x float> poison, float %9698, i64 0, !dbg !331 + %9853 = insertelement <2 x float> %9852, float %9699, i64 1, !dbg !331 + %9854 = fptrunc <2 x float> %9853 to <2 x bfloat>, !dbg !331 + %9855 = insertelement <2 x float> poison, float %9700, i64 0, !dbg !331 + %9856 = insertelement <2 x float> %9855, float %9701, i64 1, !dbg !331 + %9857 = fptrunc <2 x float> %9856 to <2 x bfloat>, !dbg !331 + %9858 = insertelement <2 x float> poison, float %9702, i64 0, !dbg !331 + %9859 = insertelement <2 x float> %9858, float %9703, i64 1, !dbg !331 + %9860 = fptrunc <2 x float> %9859 to <2 x bfloat>, !dbg !331 + %9861 = insertelement <2 x float> poison, float %9704, i64 0, !dbg !331 + %9862 = insertelement <2 x float> %9861, float %9705, i64 1, !dbg !331 + %9863 = fptrunc <2 x float> %9862 to <2 x bfloat>, !dbg !331 + %9864 = insertelement <2 x float> poison, float %9706, i64 0, !dbg !331 + %9865 = insertelement <2 x float> %9864, float %9707, i64 1, !dbg !331 + %9866 = fptrunc <2 x float> %9865 to <2 x bfloat>, !dbg !331 + %9867 = insertelement <2 x float> poison, float %9708, i64 0, !dbg !331 + %9868 = insertelement <2 x 
float> %9867, float %9709, i64 1, !dbg !331 + %9869 = fptrunc <2 x float> %9868 to <2 x bfloat>, !dbg !331 + %9870 = insertelement <2 x float> poison, float %9710, i64 0, !dbg !331 + %9871 = insertelement <2 x float> %9870, float %9711, i64 1, !dbg !331 + %9872 = fptrunc <2 x float> %9871 to <2 x bfloat>, !dbg !331 + %9873 = insertelement <2 x float> poison, float %9712, i64 0, !dbg !331 + %9874 = insertelement <2 x float> %9873, float %9713, i64 1, !dbg !331 + %9875 = fptrunc <2 x float> %9874 to <2 x bfloat>, !dbg !331 + %9876 = insertelement <2 x float> poison, float %9714, i64 0, !dbg !331 + %9877 = insertelement <2 x float> %9876, float %9715, i64 1, !dbg !331 + %9878 = fptrunc <2 x float> %9877 to <2 x bfloat>, !dbg !331 + %9879 = insertelement <2 x float> poison, float %9716, i64 0, !dbg !331 + %9880 = insertelement <2 x float> %9879, float %9717, i64 1, !dbg !331 + %9881 = fptrunc <2 x float> %9880 to <2 x bfloat>, !dbg !331 + %9882 = insertelement <2 x float> poison, float %9718, i64 0, !dbg !331 + %9883 = insertelement <2 x float> %9882, float %9719, i64 1, !dbg !331 + %9884 = fptrunc <2 x float> %9883 to <2 x bfloat>, !dbg !331 + %9885 = insertelement <2 x float> poison, float %9720, i64 0, !dbg !331 + %9886 = insertelement <2 x float> %9885, float %9721, i64 1, !dbg !331 + %9887 = fptrunc <2 x float> %9886 to <2 x bfloat>, !dbg !331 + %9888 = insertelement <2 x float> poison, float %9722, i64 0, !dbg !331 + %9889 = insertelement <2 x float> %9888, float %9723, i64 1, !dbg !331 + %9890 = fptrunc <2 x float> %9889 to <2 x bfloat>, !dbg !331 + %9891 = insertelement <2 x float> poison, float %9724, i64 0, !dbg !331 + %9892 = insertelement <2 x float> %9891, float %9725, i64 1, !dbg !331 + %9893 = fptrunc <2 x float> %9892 to <2 x bfloat>, !dbg !331 + %9894 = insertelement <2 x float> poison, float %9726, i64 0, !dbg !331 + %9895 = insertelement <2 x float> %9894, float %9727, i64 1, !dbg !331 + %9896 = fptrunc <2 x float> %9895 to <2 x bfloat>, !dbg !331 + 
%9897 = insertelement <2 x float> poison, float %9728, i64 0, !dbg !331 + %9898 = insertelement <2 x float> %9897, float %9729, i64 1, !dbg !331 + %9899 = fptrunc <2 x float> %9898 to <2 x bfloat>, !dbg !331 + %9900 = insertelement <2 x float> poison, float %9730, i64 0, !dbg !331 + %9901 = insertelement <2 x float> %9900, float %9731, i64 1, !dbg !331 + %9902 = fptrunc <2 x float> %9901 to <2 x bfloat>, !dbg !331 + %9903 = insertelement <2 x float> poison, float %9732, i64 0, !dbg !331 + %9904 = insertelement <2 x float> %9903, float %9733, i64 1, !dbg !331 + %9905 = fptrunc <2 x float> %9904 to <2 x bfloat>, !dbg !331 + %9906 = insertelement <2 x float> poison, float %9734, i64 0, !dbg !331 + %9907 = insertelement <2 x float> %9906, float %9735, i64 1, !dbg !331 + %9908 = fptrunc <2 x float> %9907 to <2 x bfloat>, !dbg !331 + %9909 = insertelement <2 x float> poison, float %9736, i64 0, !dbg !331 + %9910 = insertelement <2 x float> %9909, float %9737, i64 1, !dbg !331 + %9911 = fptrunc <2 x float> %9910 to <2 x bfloat>, !dbg !331 + %9912 = insertelement <2 x float> poison, float %9738, i64 0, !dbg !331 + %9913 = insertelement <2 x float> %9912, float %9739, i64 1, !dbg !331 + %9914 = fptrunc <2 x float> %9913 to <2 x bfloat>, !dbg !331 + %9915 = insertelement <2 x float> poison, float %9740, i64 0, !dbg !331 + %9916 = insertelement <2 x float> %9915, float %9741, i64 1, !dbg !331 + %9917 = fptrunc <2 x float> %9916 to <2 x bfloat>, !dbg !331 + %9918 = insertelement <2 x float> poison, float %9742, i64 0, !dbg !331 + %9919 = insertelement <2 x float> %9918, float %9743, i64 1, !dbg !331 + %9920 = fptrunc <2 x float> %9919 to <2 x bfloat>, !dbg !331 + %9921 = shl nuw nsw i32 %4713, 13, !dbg !331 + %9922 = shl nuw nsw i32 %35, 5, !dbg !331 + %9923 = and i32 %9922, 7264, !dbg !331 + %9924 = and i32 %35, 24, !dbg !331 + %9925 = shl nuw nsw i32 %9924, 4, !dbg !331 + %9926 = shl nuw nsw i32 %35, 2, !dbg !331 + %9927 = and i32 %9926, 16, !dbg !331 + %9928 = or disjoint 
i32 %9921, %9927, !dbg !331 + %9929 = or disjoint i32 %9923, %9925, !dbg !331 + %9930 = or disjoint i32 %9928, %9929, !dbg !331 + %9931 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %9930, !dbg !331 + %9932 = bitcast <2 x bfloat> %9827 to i32, !dbg !331 + %9933 = bitcast <2 x bfloat> %9833 to i32, !dbg !331 + %9934 = bitcast <2 x bfloat> %9839 to i32, !dbg !331 + %9935 = bitcast <2 x bfloat> %9845 to i32, !dbg !331 + %9936 = insertelement <4 x i32> poison, i32 %9932, i64 0, !dbg !331 + %9937 = insertelement <4 x i32> %9936, i32 %9933, i64 1, !dbg !331 + %9938 = insertelement <4 x i32> %9937, i32 %9934, i64 2, !dbg !331 + %9939 = insertelement <4 x i32> %9938, i32 %9935, i64 3, !dbg !331 + store <4 x i32> %9939, ptr addrspace(3) %9931, align 16, !dbg !331 + %9940 = getelementptr inbounds nuw i8, ptr addrspace(3) %9931, i32 512, !dbg !331 + %9941 = bitcast <2 x bfloat> %9830 to i32, !dbg !331 + %9942 = bitcast <2 x bfloat> %9836 to i32, !dbg !331 + %9943 = bitcast <2 x bfloat> %9842 to i32, !dbg !331 + %9944 = bitcast <2 x bfloat> %9848 to i32, !dbg !331 + %9945 = insertelement <4 x i32> poison, i32 %9941, i64 0, !dbg !331 + %9946 = insertelement <4 x i32> %9945, i32 %9942, i64 1, !dbg !331 + %9947 = insertelement <4 x i32> %9946, i32 %9943, i64 2, !dbg !331 + %9948 = insertelement <4 x i32> %9947, i32 %9944, i64 3, !dbg !331 + store <4 x i32> %9948, ptr addrspace(3) %9940, align 16, !dbg !331 + %9949 = xor i32 %9930, 32, !dbg !331 + %9950 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %9949, !dbg !331 + %9951 = bitcast <2 x bfloat> %9851 to i32, !dbg !331 + %9952 = bitcast <2 x bfloat> %9857 to i32, !dbg !331 + %9953 = bitcast <2 x bfloat> %9863 to i32, !dbg !331 + %9954 = bitcast <2 x bfloat> %9869 to i32, !dbg !331 + %9955 = insertelement <4 x i32> poison, i32 %9951, i64 0, !dbg !331 + %9956 = insertelement <4 x i32> %9955, i32 %9952, i64 1, !dbg !331 + %9957 = insertelement <4 x i32> %9956, i32 %9953, i64 2, !dbg !331 + 
%9958 = insertelement <4 x i32> %9957, i32 %9954, i64 3, !dbg !331 + store <4 x i32> %9958, ptr addrspace(3) %9950, align 16, !dbg !331 + %9959 = getelementptr inbounds nuw i8, ptr addrspace(3) %9950, i32 512, !dbg !331 + %9960 = bitcast <2 x bfloat> %9854 to i32, !dbg !331 + %9961 = bitcast <2 x bfloat> %9860 to i32, !dbg !331 + %9962 = bitcast <2 x bfloat> %9866 to i32, !dbg !331 + %9963 = bitcast <2 x bfloat> %9872 to i32, !dbg !331 + %9964 = insertelement <4 x i32> poison, i32 %9960, i64 0, !dbg !331 + %9965 = insertelement <4 x i32> %9964, i32 %9961, i64 1, !dbg !331 + %9966 = insertelement <4 x i32> %9965, i32 %9962, i64 2, !dbg !331 + %9967 = insertelement <4 x i32> %9966, i32 %9963, i64 3, !dbg !331 + store <4 x i32> %9967, ptr addrspace(3) %9959, align 16, !dbg !331 + %9968 = xor i32 %9930, 64, !dbg !331 + %9969 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %9968, !dbg !331 + %9970 = bitcast <2 x bfloat> %9875 to i32, !dbg !331 + %9971 = bitcast <2 x bfloat> %9881 to i32, !dbg !331 + %9972 = bitcast <2 x bfloat> %9887 to i32, !dbg !331 + %9973 = bitcast <2 x bfloat> %9893 to i32, !dbg !331 + %9974 = insertelement <4 x i32> poison, i32 %9970, i64 0, !dbg !331 + %9975 = insertelement <4 x i32> %9974, i32 %9971, i64 1, !dbg !331 + %9976 = insertelement <4 x i32> %9975, i32 %9972, i64 2, !dbg !331 + %9977 = insertelement <4 x i32> %9976, i32 %9973, i64 3, !dbg !331 + store <4 x i32> %9977, ptr addrspace(3) %9969, align 16, !dbg !331 + %9978 = getelementptr inbounds nuw i8, ptr addrspace(3) %9969, i32 512, !dbg !331 + %9979 = bitcast <2 x bfloat> %9878 to i32, !dbg !331 + %9980 = bitcast <2 x bfloat> %9884 to i32, !dbg !331 + %9981 = bitcast <2 x bfloat> %9890 to i32, !dbg !331 + %9982 = bitcast <2 x bfloat> %9896 to i32, !dbg !331 + %9983 = insertelement <4 x i32> poison, i32 %9979, i64 0, !dbg !331 + %9984 = insertelement <4 x i32> %9983, i32 %9980, i64 1, !dbg !331 + %9985 = insertelement <4 x i32> %9984, i32 %9981, i64 2, !dbg !331 + 
%9986 = insertelement <4 x i32> %9985, i32 %9982, i64 3, !dbg !331 + store <4 x i32> %9986, ptr addrspace(3) %9978, align 16, !dbg !331 + %9987 = xor i32 %9930, 96, !dbg !331 + %9988 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %9987, !dbg !331 + %9989 = bitcast <2 x bfloat> %9899 to i32, !dbg !331 + %9990 = bitcast <2 x bfloat> %9905 to i32, !dbg !331 + %9991 = bitcast <2 x bfloat> %9911 to i32, !dbg !331 + %9992 = bitcast <2 x bfloat> %9917 to i32, !dbg !331 + %9993 = insertelement <4 x i32> poison, i32 %9989, i64 0, !dbg !331 + %9994 = insertelement <4 x i32> %9993, i32 %9990, i64 1, !dbg !331 + %9995 = insertelement <4 x i32> %9994, i32 %9991, i64 2, !dbg !331 + %9996 = insertelement <4 x i32> %9995, i32 %9992, i64 3, !dbg !331 + store <4 x i32> %9996, ptr addrspace(3) %9988, align 16, !dbg !331 + %9997 = getelementptr inbounds nuw i8, ptr addrspace(3) %9988, i32 512, !dbg !331 + %9998 = bitcast <2 x bfloat> %9902 to i32, !dbg !331 + %9999 = bitcast <2 x bfloat> %9908 to i32, !dbg !331 + %10000 = bitcast <2 x bfloat> %9914 to i32, !dbg !331 + %10001 = bitcast <2 x bfloat> %9920 to i32, !dbg !331 + %10002 = insertelement <4 x i32> poison, i32 %9998, i64 0, !dbg !331 + %10003 = insertelement <4 x i32> %10002, i32 %9999, i64 1, !dbg !331 + %10004 = insertelement <4 x i32> %10003, i32 %10000, i64 2, !dbg !331 + %10005 = insertelement <4 x i32> %10004, i32 %10001, i64 3, !dbg !331 + store <4 x i32> %10005, ptr addrspace(3) %9997, align 16, !dbg !331 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !331 + %10006 = shl nuw nsw i32 %9924, 10, !dbg !331 + %10007 = shl nuw nsw i32 %4713, 5, !dbg !331 + %10008 = and i32 %9926, 1008, !dbg !331 + %10009 = or disjoint i32 %10006, %10007, !dbg !331 + %10010 = xor i32 %10009, %10008, !dbg !331 + %10011 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %10010, !dbg !331 + %10012 = ptrtoint ptr addrspace(3) %10011 to i32, !dbg !331 + %10013 = tail call { i32, i32, 
i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10012) #3, !dbg !331 + %10014 = extractvalue { i32, i32, i32, i32 } %10013, 0, !dbg !331 + %10015 = extractvalue { i32, i32, i32, i32 } %10013, 1, !dbg !331 + %10016 = extractvalue { i32, i32, i32, i32 } %10013, 2, !dbg !331 + %10017 = extractvalue { i32, i32, i32, i32 } %10013, 3, !dbg !331 + %10018 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 1024, !dbg !331 + %10019 = ptrtoint ptr addrspace(3) %10018 to i32, !dbg !331 + %10020 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10019) #3, !dbg !331 + %10021 = extractvalue { i32, i32, i32, i32 } %10020, 0, !dbg !331 + %10022 = extractvalue { i32, i32, i32, i32 } %10020, 1, !dbg !331 + %10023 = extractvalue { i32, i32, i32, i32 } %10020, 2, !dbg !331 + %10024 = extractvalue { i32, i32, i32, i32 } %10020, 3, !dbg !331 + %10025 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 2048, !dbg !331 + %10026 = ptrtoint ptr addrspace(3) %10025 to i32, !dbg !331 + %10027 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10026) #3, !dbg !331 + %10028 = extractvalue { i32, i32, i32, i32 } %10027, 0, !dbg !331 + %10029 = extractvalue { i32, i32, i32, i32 } %10027, 1, !dbg !331 + %10030 = extractvalue { i32, i32, i32, i32 } %10027, 2, !dbg !331 + %10031 = extractvalue { i32, i32, i32, i32 } %10027, 3, !dbg !331 + %10032 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 3072, !dbg !331 + %10033 = ptrtoint ptr addrspace(3) %10032 to i32, !dbg !331 + %10034 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10033) #3, !dbg !331 + %10035 = extractvalue { i32, i32, i32, i32 } %10034, 0, !dbg !331 + %10036 = extractvalue { 
i32, i32, i32, i32 } %10034, 1, !dbg !331 + %10037 = extractvalue { i32, i32, i32, i32 } %10034, 2, !dbg !331 + %10038 = extractvalue { i32, i32, i32, i32 } %10034, 3, !dbg !331 + %10039 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 4096, !dbg !331 + %10040 = ptrtoint ptr addrspace(3) %10039 to i32, !dbg !331 + %10041 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10040) #3, !dbg !331 + %10042 = extractvalue { i32, i32, i32, i32 } %10041, 0, !dbg !331 + %10043 = extractvalue { i32, i32, i32, i32 } %10041, 1, !dbg !331 + %10044 = extractvalue { i32, i32, i32, i32 } %10041, 2, !dbg !331 + %10045 = extractvalue { i32, i32, i32, i32 } %10041, 3, !dbg !331 + %10046 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 5120, !dbg !331 + %10047 = ptrtoint ptr addrspace(3) %10046 to i32, !dbg !331 + %10048 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10047) #3, !dbg !331 + %10049 = extractvalue { i32, i32, i32, i32 } %10048, 0, !dbg !331 + %10050 = extractvalue { i32, i32, i32, i32 } %10048, 1, !dbg !331 + %10051 = extractvalue { i32, i32, i32, i32 } %10048, 2, !dbg !331 + %10052 = extractvalue { i32, i32, i32, i32 } %10048, 3, !dbg !331 + %10053 = getelementptr inbounds nuw i8, ptr addrspace(3) %10011, i32 6144, !dbg !331 + %10054 = ptrtoint ptr addrspace(3) %10053 to i32, !dbg !331 + %10055 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10054) #3, !dbg !331 + %10056 = extractvalue { i32, i32, i32, i32 } %10055, 0, !dbg !331 + %10057 = extractvalue { i32, i32, i32, i32 } %10055, 1, !dbg !331 + %10058 = extractvalue { i32, i32, i32, i32 } %10055, 2, !dbg !331 + %10059 = extractvalue { i32, i32, i32, i32 } %10055, 3, !dbg !331 + %10060 = getelementptr inbounds nuw i8, ptr 
addrspace(3) %10011, i32 7168, !dbg !331 + %10061 = ptrtoint ptr addrspace(3) %10060 to i32, !dbg !331 + %10062 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10061) #3, !dbg !331 + %10063 = extractvalue { i32, i32, i32, i32 } %10062, 0, !dbg !331 + %10064 = extractvalue { i32, i32, i32, i32 } %10062, 1, !dbg !331 + %10065 = extractvalue { i32, i32, i32, i32 } %10062, 2, !dbg !331 + %10066 = extractvalue { i32, i32, i32, i32 } %10062, 3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10014, i32 %10015, i32 %10016, i32 %10017, ptr addrspace(1) %9817) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10021, i32 %10022, i32 %10023, i32 %10024, ptr addrspace(1) %9818) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10028, i32 %10029, i32 %10030, i32 %10031, ptr addrspace(1) %9819) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10035, i32 %10036, i32 %10037, i32 %10038, ptr addrspace(1) %9820) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10042, i32 %10043, i32 %10044, i32 %10045, ptr addrspace(1) %9821) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10049, i32 %10050, i32 %10051, i32 %10052, ptr addrspace(1) %9822) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10056, i32 %10057, i32 %10058, i32 %10059, ptr addrspace(1) %9823) #3, !dbg !331 + tail call void asm sideeffect "st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l"(i32 %10063, i32 %10064, i32 %10065, i32 %10066, ptr addrspace(1) %9824) #3, !dbg !331 
+ %10067 = insertelement <2 x float> poison, float %9744, i64 0, !dbg !332 + %10068 = insertelement <2 x float> %10067, float %9745, i64 1, !dbg !332 + %10069 = fmul <2 x float> %10068, splat (float 0x3FB6A09E60000000), !dbg !332 + %10070 = insertelement <2 x float> poison, float %9746, i64 0, !dbg !332 + %10071 = insertelement <2 x float> %10070, float %9747, i64 1, !dbg !332 + %10072 = fmul <2 x float> %10071, splat (float 0x3FB6A09E60000000), !dbg !332 + %10073 = insertelement <2 x float> poison, float %9748, i64 0, !dbg !332 + %10074 = insertelement <2 x float> %10073, float %9749, i64 1, !dbg !332 + %10075 = fmul <2 x float> %10074, splat (float 0x3FB6A09E60000000), !dbg !332 + %10076 = insertelement <2 x float> poison, float %9750, i64 0, !dbg !332 + %10077 = insertelement <2 x float> %10076, float %9751, i64 1, !dbg !332 + %10078 = fmul <2 x float> %10077, splat (float 0x3FB6A09E60000000), !dbg !332 + %10079 = insertelement <2 x float> poison, float %9752, i64 0, !dbg !332 + %10080 = insertelement <2 x float> %10079, float %9753, i64 1, !dbg !332 + %10081 = fmul <2 x float> %10080, splat (float 0x3FB6A09E60000000), !dbg !332 + %10082 = insertelement <2 x float> poison, float %9754, i64 0, !dbg !332 + %10083 = insertelement <2 x float> %10082, float %9755, i64 1, !dbg !332 + %10084 = fmul <2 x float> %10083, splat (float 0x3FB6A09E60000000), !dbg !332 + %10085 = insertelement <2 x float> poison, float %9756, i64 0, !dbg !332 + %10086 = insertelement <2 x float> %10085, float %9757, i64 1, !dbg !332 + %10087 = fmul <2 x float> %10086, splat (float 0x3FB6A09E60000000), !dbg !332 + %10088 = insertelement <2 x float> poison, float %9758, i64 0, !dbg !332 + %10089 = insertelement <2 x float> %10088, float %9759, i64 1, !dbg !332 + %10090 = fmul <2 x float> %10089, splat (float 0x3FB6A09E60000000), !dbg !332 + %10091 = insertelement <2 x float> poison, float %9760, i64 0, !dbg !332 + %10092 = insertelement <2 x float> %10091, float %9761, i64 1, !dbg !332 + %10093 
= fmul <2 x float> %10092, splat (float 0x3FB6A09E60000000), !dbg !332 + %10094 = insertelement <2 x float> poison, float %9762, i64 0, !dbg !332 + %10095 = insertelement <2 x float> %10094, float %9763, i64 1, !dbg !332 + %10096 = fmul <2 x float> %10095, splat (float 0x3FB6A09E60000000), !dbg !332 + %10097 = insertelement <2 x float> poison, float %9764, i64 0, !dbg !332 + %10098 = insertelement <2 x float> %10097, float %9765, i64 1, !dbg !332 + %10099 = fmul <2 x float> %10098, splat (float 0x3FB6A09E60000000), !dbg !332 + %10100 = insertelement <2 x float> poison, float %9766, i64 0, !dbg !332 + %10101 = insertelement <2 x float> %10100, float %9767, i64 1, !dbg !332 + %10102 = fmul <2 x float> %10101, splat (float 0x3FB6A09E60000000), !dbg !332 + %10103 = insertelement <2 x float> poison, float %9768, i64 0, !dbg !332 + %10104 = insertelement <2 x float> %10103, float %9769, i64 1, !dbg !332 + %10105 = fmul <2 x float> %10104, splat (float 0x3FB6A09E60000000), !dbg !332 + %10106 = insertelement <2 x float> poison, float %9770, i64 0, !dbg !332 + %10107 = insertelement <2 x float> %10106, float %9771, i64 1, !dbg !332 + %10108 = fmul <2 x float> %10107, splat (float 0x3FB6A09E60000000), !dbg !332 + %10109 = insertelement <2 x float> poison, float %9772, i64 0, !dbg !332 + %10110 = insertelement <2 x float> %10109, float %9773, i64 1, !dbg !332 + %10111 = fmul <2 x float> %10110, splat (float 0x3FB6A09E60000000), !dbg !332 + %10112 = insertelement <2 x float> poison, float %9774, i64 0, !dbg !332 + %10113 = insertelement <2 x float> %10112, float %9775, i64 1, !dbg !332 + %10114 = fmul <2 x float> %10113, splat (float 0x3FB6A09E60000000), !dbg !332 + %10115 = insertelement <2 x float> poison, float %9776, i64 0, !dbg !332 + %10116 = insertelement <2 x float> %10115, float %9777, i64 1, !dbg !332 + %10117 = fmul <2 x float> %10116, splat (float 0x3FB6A09E60000000), !dbg !332 + %10118 = insertelement <2 x float> poison, float %9778, i64 0, !dbg !332 + %10119 = 
insertelement <2 x float> %10118, float %9779, i64 1, !dbg !332 + %10120 = fmul <2 x float> %10119, splat (float 0x3FB6A09E60000000), !dbg !332 + %10121 = insertelement <2 x float> poison, float %9780, i64 0, !dbg !332 + %10122 = insertelement <2 x float> %10121, float %9781, i64 1, !dbg !332 + %10123 = fmul <2 x float> %10122, splat (float 0x3FB6A09E60000000), !dbg !332 + %10124 = insertelement <2 x float> poison, float %9782, i64 0, !dbg !332 + %10125 = insertelement <2 x float> %10124, float %9783, i64 1, !dbg !332 + %10126 = fmul <2 x float> %10125, splat (float 0x3FB6A09E60000000), !dbg !332 + %10127 = insertelement <2 x float> poison, float %9784, i64 0, !dbg !332 + %10128 = insertelement <2 x float> %10127, float %9785, i64 1, !dbg !332 + %10129 = fmul <2 x float> %10128, splat (float 0x3FB6A09E60000000), !dbg !332 + %10130 = insertelement <2 x float> poison, float %9786, i64 0, !dbg !332 + %10131 = insertelement <2 x float> %10130, float %9787, i64 1, !dbg !332 + %10132 = fmul <2 x float> %10131, splat (float 0x3FB6A09E60000000), !dbg !332 + %10133 = insertelement <2 x float> poison, float %9788, i64 0, !dbg !332 + %10134 = insertelement <2 x float> %10133, float %9789, i64 1, !dbg !332 + %10135 = fmul <2 x float> %10134, splat (float 0x3FB6A09E60000000), !dbg !332 + %10136 = insertelement <2 x float> poison, float %9790, i64 0, !dbg !332 + %10137 = insertelement <2 x float> %10136, float %9791, i64 1, !dbg !332 + %10138 = fmul <2 x float> %10137, splat (float 0x3FB6A09E60000000), !dbg !332 + %10139 = insertelement <2 x float> poison, float %9792, i64 0, !dbg !332 + %10140 = insertelement <2 x float> %10139, float %9793, i64 1, !dbg !332 + %10141 = fmul <2 x float> %10140, splat (float 0x3FB6A09E60000000), !dbg !332 + %10142 = insertelement <2 x float> poison, float %9794, i64 0, !dbg !332 + %10143 = insertelement <2 x float> %10142, float %9795, i64 1, !dbg !332 + %10144 = fmul <2 x float> %10143, splat (float 0x3FB6A09E60000000), !dbg !332 + %10145 = 
insertelement <2 x float> poison, float %9796, i64 0, !dbg !332 + %10146 = insertelement <2 x float> %10145, float %9797, i64 1, !dbg !332 + %10147 = fmul <2 x float> %10146, splat (float 0x3FB6A09E60000000), !dbg !332 + %10148 = insertelement <2 x float> poison, float %9798, i64 0, !dbg !332 + %10149 = insertelement <2 x float> %10148, float %9799, i64 1, !dbg !332 + %10150 = fmul <2 x float> %10149, splat (float 0x3FB6A09E60000000), !dbg !332 + %10151 = insertelement <2 x float> poison, float %9800, i64 0, !dbg !332 + %10152 = insertelement <2 x float> %10151, float %9801, i64 1, !dbg !332 + %10153 = fmul <2 x float> %10152, splat (float 0x3FB6A09E60000000), !dbg !332 + %10154 = insertelement <2 x float> poison, float %9802, i64 0, !dbg !332 + %10155 = insertelement <2 x float> %10154, float %9803, i64 1, !dbg !332 + %10156 = fmul <2 x float> %10155, splat (float 0x3FB6A09E60000000), !dbg !332 + %10157 = insertelement <2 x float> poison, float %9804, i64 0, !dbg !332 + %10158 = insertelement <2 x float> %10157, float %9805, i64 1, !dbg !332 + %10159 = fmul <2 x float> %10158, splat (float 0x3FB6A09E60000000), !dbg !332 + %10160 = insertelement <2 x float> poison, float %9806, i64 0, !dbg !332 + %10161 = insertelement <2 x float> %10160, float %9807, i64 1, !dbg !332 + %10162 = fmul <2 x float> %10161, splat (float 0x3FB6A09E60000000), !dbg !332 + %10163 = or disjoint i32 %4472, %30, !dbg !333 + %10164 = or disjoint i32 %10163, %4497, !dbg !333 + %10165 = or disjoint i32 %4473, %30, !dbg !333 + %10166 = or disjoint i32 %10165, %4497, !dbg !333 + %10167 = or disjoint i32 %4474, %30, !dbg !333 + %10168 = or disjoint i32 %10167, %4497, !dbg !333 + %10169 = or disjoint i32 %4475, %30, !dbg !333 + %10170 = or disjoint i32 %10169, %4497, !dbg !333 + %10171 = or disjoint i32 %4476, %30, !dbg !333 + %10172 = or disjoint i32 %10171, %4497, !dbg !333 + %10173 = or disjoint i32 %4477, %30, !dbg !333 + %10174 = or disjoint i32 %10173, %4497, !dbg !333 + %10175 = or disjoint 
i32 %4478, %30, !dbg !333 + %10176 = or disjoint i32 %10175, %4497, !dbg !333 + %10177 = or disjoint i32 %4479, %30, !dbg !333 + %10178 = or disjoint i32 %10177, %4497, !dbg !333 + %10179 = sext i32 %10164 to i64, !dbg !334 + %10180 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10179, !dbg !334 + %10181 = sext i32 %10166 to i64, !dbg !334 + %10182 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10181, !dbg !334 + %10183 = sext i32 %10168 to i64, !dbg !334 + %10184 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10183, !dbg !334 + %10185 = sext i32 %10170 to i64, !dbg !334 + %10186 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10185, !dbg !334 + %10187 = sext i32 %10172 to i64, !dbg !334 + %10188 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10187, !dbg !334 + %10189 = sext i32 %10174 to i64, !dbg !334 + %10190 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10189, !dbg !334 + %10191 = sext i32 %10176 to i64, !dbg !334 + %10192 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10191, !dbg !334 + %10193 = sext i32 %10178 to i64, !dbg !334 + %10194 = getelementptr bfloat, ptr addrspace(1) %17, i64 %10193, !dbg !334 + %10195 = fptrunc <2 x float> %10069 to <2 x bfloat>, !dbg !335 + %10196 = fptrunc <2 x float> %10072 to <2 x bfloat>, !dbg !335 + %10197 = fptrunc <2 x float> %10075 to <2 x bfloat>, !dbg !335 + %10198 = fptrunc <2 x float> %10078 to <2 x bfloat>, !dbg !335 + %10199 = fptrunc <2 x float> %10081 to <2 x bfloat>, !dbg !335 + %10200 = fptrunc <2 x float> %10084 to <2 x bfloat>, !dbg !335 + %10201 = fptrunc <2 x float> %10087 to <2 x bfloat>, !dbg !335 + %10202 = fptrunc <2 x float> %10090 to <2 x bfloat>, !dbg !335 + %10203 = fptrunc <2 x float> %10093 to <2 x bfloat>, !dbg !335 + %10204 = fptrunc <2 x float> %10096 to <2 x bfloat>, !dbg !335 + %10205 = fptrunc <2 x float> %10099 to <2 x bfloat>, !dbg !335 + %10206 = fptrunc <2 x float> %10102 to <2 x bfloat>, !dbg !335 + %10207 = fptrunc <2 x float> %10105 to <2 x bfloat>, !dbg !335 
+ %10208 = fptrunc <2 x float> %10108 to <2 x bfloat>, !dbg !335 + %10209 = fptrunc <2 x float> %10111 to <2 x bfloat>, !dbg !335 + %10210 = fptrunc <2 x float> %10114 to <2 x bfloat>, !dbg !335 + %10211 = fptrunc <2 x float> %10117 to <2 x bfloat>, !dbg !335 + %10212 = fptrunc <2 x float> %10120 to <2 x bfloat>, !dbg !335 + %10213 = fptrunc <2 x float> %10123 to <2 x bfloat>, !dbg !335 + %10214 = fptrunc <2 x float> %10126 to <2 x bfloat>, !dbg !335 + %10215 = fptrunc <2 x float> %10129 to <2 x bfloat>, !dbg !335 + %10216 = fptrunc <2 x float> %10132 to <2 x bfloat>, !dbg !335 + %10217 = fptrunc <2 x float> %10135 to <2 x bfloat>, !dbg !335 + %10218 = fptrunc <2 x float> %10138 to <2 x bfloat>, !dbg !335 + %10219 = fptrunc <2 x float> %10141 to <2 x bfloat>, !dbg !335 + %10220 = fptrunc <2 x float> %10144 to <2 x bfloat>, !dbg !335 + %10221 = fptrunc <2 x float> %10147 to <2 x bfloat>, !dbg !335 + %10222 = fptrunc <2 x float> %10150 to <2 x bfloat>, !dbg !335 + %10223 = fptrunc <2 x float> %10153 to <2 x bfloat>, !dbg !335 + %10224 = fptrunc <2 x float> %10156 to <2 x bfloat>, !dbg !335 + %10225 = fptrunc <2 x float> %10159 to <2 x bfloat>, !dbg !335 + %10226 = fptrunc <2 x float> %10162 to <2 x bfloat>, !dbg !335 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !335 + %10227 = bitcast <2 x bfloat> %10195 to i32, !dbg !335 + %10228 = bitcast <2 x bfloat> %10197 to i32, !dbg !335 + %10229 = bitcast <2 x bfloat> %10199 to i32, !dbg !335 + %10230 = bitcast <2 x bfloat> %10201 to i32, !dbg !335 + %10231 = insertelement <4 x i32> poison, i32 %10227, i64 0, !dbg !335 + %10232 = insertelement <4 x i32> %10231, i32 %10228, i64 1, !dbg !335 + %10233 = insertelement <4 x i32> %10232, i32 %10229, i64 2, !dbg !335 + %10234 = insertelement <4 x i32> %10233, i32 %10230, i64 3, !dbg !335 + store <4 x i32> %10234, ptr addrspace(3) %9931, align 16, !dbg !335 + %10235 = bitcast <2 x bfloat> %10196 to i32, !dbg !335 + %10236 = bitcast <2 x bfloat> %10198 to i32, 
!dbg !335 + %10237 = bitcast <2 x bfloat> %10200 to i32, !dbg !335 + %10238 = bitcast <2 x bfloat> %10202 to i32, !dbg !335 + %10239 = insertelement <4 x i32> poison, i32 %10235, i64 0, !dbg !335 + %10240 = insertelement <4 x i32> %10239, i32 %10236, i64 1, !dbg !335 + %10241 = insertelement <4 x i32> %10240, i32 %10237, i64 2, !dbg !335 + %10242 = insertelement <4 x i32> %10241, i32 %10238, i64 3, !dbg !335 + store <4 x i32> %10242, ptr addrspace(3) %9940, align 16, !dbg !335 + %10243 = bitcast <2 x bfloat> %10203 to i32, !dbg !335 + %10244 = bitcast <2 x bfloat> %10205 to i32, !dbg !335 + %10245 = bitcast <2 x bfloat> %10207 to i32, !dbg !335 + %10246 = bitcast <2 x bfloat> %10209 to i32, !dbg !335 + %10247 = insertelement <4 x i32> poison, i32 %10243, i64 0, !dbg !335 + %10248 = insertelement <4 x i32> %10247, i32 %10244, i64 1, !dbg !335 + %10249 = insertelement <4 x i32> %10248, i32 %10245, i64 2, !dbg !335 + %10250 = insertelement <4 x i32> %10249, i32 %10246, i64 3, !dbg !335 + store <4 x i32> %10250, ptr addrspace(3) %9950, align 16, !dbg !335 + %10251 = bitcast <2 x bfloat> %10204 to i32, !dbg !335 + %10252 = bitcast <2 x bfloat> %10206 to i32, !dbg !335 + %10253 = bitcast <2 x bfloat> %10208 to i32, !dbg !335 + %10254 = bitcast <2 x bfloat> %10210 to i32, !dbg !335 + %10255 = insertelement <4 x i32> poison, i32 %10251, i64 0, !dbg !335 + %10256 = insertelement <4 x i32> %10255, i32 %10252, i64 1, !dbg !335 + %10257 = insertelement <4 x i32> %10256, i32 %10253, i64 2, !dbg !335 + %10258 = insertelement <4 x i32> %10257, i32 %10254, i64 3, !dbg !335 + store <4 x i32> %10258, ptr addrspace(3) %9959, align 16, !dbg !335 + %10259 = bitcast <2 x bfloat> %10211 to i32, !dbg !335 + %10260 = bitcast <2 x bfloat> %10213 to i32, !dbg !335 + %10261 = bitcast <2 x bfloat> %10215 to i32, !dbg !335 + %10262 = bitcast <2 x bfloat> %10217 to i32, !dbg !335 + %10263 = insertelement <4 x i32> poison, i32 %10259, i64 0, !dbg !335 + %10264 = insertelement <4 x i32> %10263, 
i32 %10260, i64 1, !dbg !335 + %10265 = insertelement <4 x i32> %10264, i32 %10261, i64 2, !dbg !335 + %10266 = insertelement <4 x i32> %10265, i32 %10262, i64 3, !dbg !335 + store <4 x i32> %10266, ptr addrspace(3) %9969, align 16, !dbg !335 + %10267 = bitcast <2 x bfloat> %10212 to i32, !dbg !335 + %10268 = bitcast <2 x bfloat> %10214 to i32, !dbg !335 + %10269 = bitcast <2 x bfloat> %10216 to i32, !dbg !335 + %10270 = bitcast <2 x bfloat> %10218 to i32, !dbg !335 + %10271 = insertelement <4 x i32> poison, i32 %10267, i64 0, !dbg !335 + %10272 = insertelement <4 x i32> %10271, i32 %10268, i64 1, !dbg !335 + %10273 = insertelement <4 x i32> %10272, i32 %10269, i64 2, !dbg !335 + %10274 = insertelement <4 x i32> %10273, i32 %10270, i64 3, !dbg !335 + store <4 x i32> %10274, ptr addrspace(3) %9978, align 16, !dbg !335 + %10275 = bitcast <2 x bfloat> %10219 to i32, !dbg !335 + %10276 = bitcast <2 x bfloat> %10221 to i32, !dbg !335 + %10277 = bitcast <2 x bfloat> %10223 to i32, !dbg !335 + %10278 = bitcast <2 x bfloat> %10225 to i32, !dbg !335 + %10279 = insertelement <4 x i32> poison, i32 %10275, i64 0, !dbg !335 + %10280 = insertelement <4 x i32> %10279, i32 %10276, i64 1, !dbg !335 + %10281 = insertelement <4 x i32> %10280, i32 %10277, i64 2, !dbg !335 + %10282 = insertelement <4 x i32> %10281, i32 %10278, i64 3, !dbg !335 + store <4 x i32> %10282, ptr addrspace(3) %9988, align 16, !dbg !335 + %10283 = bitcast <2 x bfloat> %10220 to i32, !dbg !335 + %10284 = bitcast <2 x bfloat> %10222 to i32, !dbg !335 + %10285 = bitcast <2 x bfloat> %10224 to i32, !dbg !335 + %10286 = bitcast <2 x bfloat> %10226 to i32, !dbg !335 + %10287 = insertelement <4 x i32> poison, i32 %10283, i64 0, !dbg !335 + %10288 = insertelement <4 x i32> %10287, i32 %10284, i64 1, !dbg !335 + %10289 = insertelement <4 x i32> %10288, i32 %10285, i64 2, !dbg !335 + %10290 = insertelement <4 x i32> %10289, i32 %10286, i64 3, !dbg !335 + store <4 x i32> %10290, ptr addrspace(3) %9997, align 16, !dbg 
!335 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !335 + %10291 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10012) #3, !dbg !335 + %10292 = extractvalue { i32, i32, i32, i32 } %10291, 0, !dbg !335 + %10293 = extractvalue { i32, i32, i32, i32 } %10291, 1, !dbg !335 + %10294 = extractvalue { i32, i32, i32, i32 } %10291, 2, !dbg !335 + %10295 = extractvalue { i32, i32, i32, i32 } %10291, 3, !dbg !335 + %10296 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10019) #3, !dbg !335 + %10297 = extractvalue { i32, i32, i32, i32 } %10296, 0, !dbg !335 + %10298 = extractvalue { i32, i32, i32, i32 } %10296, 1, !dbg !335 + %10299 = extractvalue { i32, i32, i32, i32 } %10296, 2, !dbg !335 + %10300 = extractvalue { i32, i32, i32, i32 } %10296, 3, !dbg !335 + %10301 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10026) #3, !dbg !335 + %10302 = extractvalue { i32, i32, i32, i32 } %10301, 0, !dbg !335 + %10303 = extractvalue { i32, i32, i32, i32 } %10301, 1, !dbg !335 + %10304 = extractvalue { i32, i32, i32, i32 } %10301, 2, !dbg !335 + %10305 = extractvalue { i32, i32, i32, i32 } %10301, 3, !dbg !335 + %10306 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10033) #3, !dbg !335 + %10307 = extractvalue { i32, i32, i32, i32 } %10306, 0, !dbg !335 + %10308 = extractvalue { i32, i32, i32, i32 } %10306, 1, !dbg !335 + %10309 = extractvalue { i32, i32, i32, i32 } %10306, 2, !dbg !335 + %10310 = extractvalue { i32, i32, i32, i32 } %10306, 3, !dbg !335 + %10311 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10040) 
#3, !dbg !335 + %10312 = extractvalue { i32, i32, i32, i32 } %10311, 0, !dbg !335 + %10313 = extractvalue { i32, i32, i32, i32 } %10311, 1, !dbg !335 + %10314 = extractvalue { i32, i32, i32, i32 } %10311, 2, !dbg !335 + %10315 = extractvalue { i32, i32, i32, i32 } %10311, 3, !dbg !335 + %10316 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10047) #3, !dbg !335 + %10317 = extractvalue { i32, i32, i32, i32 } %10316, 0, !dbg !335 + %10318 = extractvalue { i32, i32, i32, i32 } %10316, 1, !dbg !335 + %10319 = extractvalue { i32, i32, i32, i32 } %10316, 2, !dbg !335 + %10320 = extractvalue { i32, i32, i32, i32 } %10316, 3, !dbg !335 + %10321 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10054) #3, !dbg !335 + %10322 = extractvalue { i32, i32, i32, i32 } %10321, 0, !dbg !335 + %10323 = extractvalue { i32, i32, i32, i32 } %10321, 1, !dbg !335 + %10324 = extractvalue { i32, i32, i32, i32 } %10321, 2, !dbg !335 + %10325 = extractvalue { i32, i32, i32, i32 } %10321, 3, !dbg !335 + %10326 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %10061) #3, !dbg !335 + %10327 = extractvalue { i32, i32, i32, i32 } %10326, 0, !dbg !335 + %10328 = extractvalue { i32, i32, i32, i32 } %10326, 1, !dbg !335 + %10329 = extractvalue { i32, i32, i32, i32 } %10326, 2, !dbg !335 + %10330 = extractvalue { i32, i32, i32, i32 } %10326, 3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10292, i32 %10293, i32 %10294, i32 %10295, ptr addrspace(1) %10180, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10297, i32 %10298, i32 %10299, i32 %10300, ptr addrspace(1) %10182, i1 true) #3, !dbg 
!335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10302, i32 %10303, i32 %10304, i32 %10305, ptr addrspace(1) %10184, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10307, i32 %10308, i32 %10309, i32 %10310, ptr addrspace(1) %10186, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10312, i32 %10313, i32 %10314, i32 %10315, ptr addrspace(1) %10188, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10317, i32 %10318, i32 %10319, i32 %10320, ptr addrspace(1) %10190, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10322, i32 %10323, i32 %10324, i32 %10325, ptr addrspace(1) %10192, i1 true) #3, !dbg !335 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %10327, i32 %10328, i32 %10329, i32 %10330, ptr addrspace(1) %10194, i1 true) #3, !dbg !335 + br label %10331, !dbg !24 + +10331: ; preds = %._crit_edge1859, %9808 + ret void, !dbg !336 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 
@llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smin.i32(i32, i32) #1 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #2 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.commit.group() #3 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.wait.group(i32 immarg) #3 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.idx.i32(i32, i32, i32, i32) #4 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.fence.sync.aligned() #5 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.commit_group.sync.aligned() #5 + +declare i32 @__nvvm_reflect(ptr) local_unnamed_addr #6 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.ftz.f(float) #7 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.f(float) #7 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smax.i32(i32, i32) #8 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.umin.i32(i32, i32) #8 + +attributes #0 = { nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind } +attributes #3 = { nounwind } +attributes #4 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #5 = { convergent nounwind } +attributes #6 = { "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" 
"unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #7 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) } +attributes #8 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = distinct !DISubprogram(name: "triton_tem_fused_zeros_1", linkageName: "triton_tem_fused_zeros_1", scope: !1, file: !1, line: 18, type: !6, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!6 = !DISubroutineType(cc: DW_CC_normal, types: !7) +!7 = !{} +!8 = !DILocation(line: 111, column: 24, scope: !5) +!9 = !DILocation(line: 115, column: 27, scope: !5) +!10 = !DILocation(line: 116, column: 28, scope: !5) +!11 = !DILocation(line: 117, column: 23, scope: !5) +!12 = !DILocation(line: 124, column: 25, scope: !5) +!13 = !DILocation(line: 124, column: 47, scope: !5) +!14 = !DILocation(line: 124, column: 35, scope: !5) +!15 = !DILocation(line: 124, column: 59, scope: !5) +!16 = !DILocation(line: 128, column: 50, scope: !5) +!17 = !DILocation(line: 128, column: 37, scope: !5) +!18 = !DILocation(line: 128, column: 61, scope: !5) +!19 = !DILocation(line: 131, column: 9, scope: !5) +!20 = !DILocation(line: 132, column: 9, scope: !5) +!21 = !DILocation(line: 133, column: 10, scope: !5) +!22 = !DILocation(line: 136, column: 26, scope: !5) +!23 = !DILocation(line: 139, column: 14, scope: !5) +!24 = !DILocation(line: 139, column: 7, scope: !5) +!25 = !DILocation(line: 140, column: 24, scope: !5) +!26 
= !DILocation(line: 144, column: 29, scope: !5) +!27 = !DILocation(line: 144, column: 54, scope: !5) +!28 = !DILocation(line: 144, column: 44, scope: !5) +!29 = !DILocation(line: 145, column: 35, scope: !5) +!30 = !DILocation(line: 154, column: 55, scope: !5) +!31 = !DILocation(line: 154, column: 78, scope: !5) +!32 = !DILocation(line: 155, column: 50, scope: !5) +!33 = !DILocation(line: 155, column: 83, scope: !5) +!34 = !DILocation(line: 155, column: 68, scope: !5) +!35 = !DILocation(line: 158, column: 30, scope: !5) +!36 = !DILocation(line: 158, column: 52, scope: !5) +!37 = !DILocation(line: 158, column: 40, scope: !5) +!38 = !DILocation(line: 158, column: 63, scope: !5) +!39 = !DILocation(line: 159, column: 32, scope: !5) +!40 = !DILocation(line: 159, column: 42, scope: !5) +!41 = !DILocation(line: 159, column: 66, scope: !5) +!42 = !DILocation(line: 161, column: 46, scope: !5) +!43 = !DILocation(line: 161, column: 56, scope: !5) +!44 = !DILocation(line: 163, column: 17, scope: !5) +!45 = !DILocation(line: 164, column: 19, scope: !5) +!46 = !DILocation(line: 167, column: 19, scope: !5) +!47 = !DILocation(line: 168, column: 21, scope: !5) +!48 = !DILocation(line: 169, column: 25, scope: !5) +!49 = !DILocation(line: 174, column: 36, scope: !5) +!50 = !DILocation(line: 175, column: 29, scope: !5) +!51 = !DILocation(line: 825, column: 38, scope: !52, inlinedAt: !53) +!52 = distinct !DILexicalBlockFile(scope: !5, file: !1, discriminator: 0) +!53 = !DILocation(line: 178, column: 107, scope: !5) +!54 = !DILocation(line: 825, column: 20, scope: !52, inlinedAt: !53) +!55 = !DILocation(line: 825, column: 56, scope: !52, inlinedAt: !53) +!56 = !DILocation(line: 825, column: 49, scope: !52, inlinedAt: !53) +!57 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !53) +!58 = !DILocation(line: 825, column: 38, scope: !52, inlinedAt: !59) +!59 = !DILocation(line: 179, column: 111, scope: !5) +!60 = !DILocation(line: 825, column: 20, scope: !52, inlinedAt: !59) +!61 = 
!DILocation(line: 825, column: 49, scope: !52, inlinedAt: !59) +!62 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !59) +!63 = !DILocation(line: 185, column: 34, scope: !5) +!64 = !DILocation(line: 185, column: 25, scope: !5) +!65 = !DILocation(line: 186, column: 33, scope: !5) +!66 = !DILocation(line: 186, column: 26, scope: !5) +!67 = !DILocation(line: 190, column: 30, scope: !5) +!68 = !DILocation(line: 190, column: 50, scope: !5) +!69 = !DILocation(line: 195, column: 30, scope: !5) +!70 = !DILocation(line: 196, column: 27, scope: !5) +!71 = !DILocation(line: 196, column: 41, scope: !5) +!72 = !DILocation(line: 197, column: 53, scope: !5) +!73 = !DILocation(line: 197, column: 39, scope: !5) +!74 = !DILocation(line: 199, column: 42, scope: !5) +!75 = !DILocation(line: 199, column: 29, scope: !5) +!76 = !DILocation(line: 390, column: 37, scope: !52, inlinedAt: !77) +!77 = !DILocation(line: 207, column: 12, scope: !5) +!78 = !DILocation(line: 390, column: 18, scope: !52, inlinedAt: !77) +!79 = !DILocation(line: 390, column: 49, scope: !52, inlinedAt: !77) +!80 = !DILocation(line: 391, column: 18, scope: !52, inlinedAt: !77) +!81 = !DILocation(line: 391, column: 49, scope: !52, inlinedAt: !77) +!82 = !DILocation(line: 395, column: 43, scope: !52, inlinedAt: !77) +!83 = !DILocation(line: 485, column: 34, scope: !52, inlinedAt: !77) +!84 = !DILocation(line: 397, column: 28, scope: !52, inlinedAt: !77) +!85 = !DILocation(line: 485, column: 23, scope: !52, inlinedAt: !77) +!86 = !DILocation(line: 488, column: 23, scope: !52, inlinedAt: !77) +!87 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !77) +!88 = !DILocation(line: 414, column: 19, scope: !52, inlinedAt: !77) +!89 = !DILocation(line: 415, column: 19, scope: !52, inlinedAt: !77) +!90 = !DILocation(line: 459, column: 19, scope: !52, inlinedAt: !77) +!91 = !DILocation(line: 395, column: 63, scope: !52, inlinedAt: !77) +!92 = !DILocation(line: 504, column: 24, scope: !52, inlinedAt: !77) +!93 
= !DILocation(line: 461, column: 14, scope: !52, inlinedAt: !77) +!94 = !DILocation(line: 482, column: 23, scope: !52, inlinedAt: !77) +!95 = !DILocation(line: 494, column: 24, scope: !52, inlinedAt: !77) +!96 = !DILocation(line: 490, column: 23, scope: !52, inlinedAt: !77) +!97 = !DILocation(line: 493, column: 24, scope: !52, inlinedAt: !77) +!98 = !DILocation(line: 496, column: 25, scope: !52, inlinedAt: !77) +!99 = !DILocation(line: 497, column: 92, scope: !52, inlinedAt: !77) +!100 = !DILocation(line: 503, column: 25, scope: !52, inlinedAt: !77) +!101 = !DILocation(line: 500, column: 24, scope: !52, inlinedAt: !77) +!102 = !DILocation(line: 501, column: 24, scope: !52, inlinedAt: !77) +!103 = !DILocation(line: 502, column: 39, scope: !52, inlinedAt: !77) +!104 = !DILocation(line: 505, column: 24, scope: !52, inlinedAt: !77) +!105 = !DILocation(line: 506, column: 23, scope: !52, inlinedAt: !77) +!106 = !DILocation(line: 513, column: 39, scope: !52, inlinedAt: !77) +!107 = !DILocation(line: 514, column: 25, scope: !52, inlinedAt: !77) +!108 = !DILocation(line: 515, column: 24, scope: !52, inlinedAt: !77) +!109 = !DILocation(line: 516, column: 24, scope: !52, inlinedAt: !77) +!110 = !DILocation(line: 524, column: 27, scope: !52, inlinedAt: !77) +!111 = !DILocation(line: 521, column: 69, scope: !52, inlinedAt: !77) +!112 = !DILocation(line: 525, column: 39, scope: !52, inlinedAt: !77) +!113 = !DILocation(line: 525, column: 21, scope: !52, inlinedAt: !77) +!114 = !DILocation(line: 530, column: 20, scope: !52, inlinedAt: !77) +!115 = !DILocation(line: 531, column: 19, scope: !52, inlinedAt: !77) +!116 = !DILocation(line: 531, column: 14, scope: !52, inlinedAt: !77) +!117 = !DILocation(line: 551, column: 15, scope: !52, inlinedAt: !77) +!118 = !DILocation(line: 549, column: 43, scope: !52, inlinedAt: !77) +!119 = !DILocation(line: 553, column: 21, scope: !52, inlinedAt: !77) +!120 = !DILocation(line: 417, column: 19, scope: !52, inlinedAt: !77) +!121 = 
!DILocation(line: 788, column: 33, scope: !52, inlinedAt: !77) +!122 = !DILocation(line: 789, column: 38, scope: !52, inlinedAt: !77) +!123 = !DILocation(line: 789, column: 24, scope: !52, inlinedAt: !77) +!124 = !DILocation(line: 790, column: 109, scope: !52, inlinedAt: !77) +!125 = !DILocation(line: 790, column: 113, scope: !52, inlinedAt: !77) +!126 = !DILocation(line: 790, column: 55, scope: !52, inlinedAt: !77) +!127 = !DILocation(line: 790, column: 25, scope: !52, inlinedAt: !77) +!128 = !DILocation(line: 791, column: 35, scope: !52, inlinedAt: !77) +!129 = !DILocation(line: 792, column: 34, scope: !52, inlinedAt: !77) +!130 = !DILocation(line: 792, column: 48, scope: !52, inlinedAt: !77) +!131 = !DILocation(line: 792, column: 63, scope: !52, inlinedAt: !77) +!132 = !DILocation(line: 793, column: 29, scope: !52, inlinedAt: !77) +!133 = !DILocation(line: 793, column: 61, scope: !52, inlinedAt: !77) +!134 = !DILocation(line: 793, column: 42, scope: !52, inlinedAt: !77) +!135 = !DILocation(line: 414, column: 28, scope: !52, inlinedAt: !77) +!136 = !DILocation(line: 214, column: 39, scope: !5) +!137 = !DILocation(line: 215, column: 31, scope: !5) +!138 = !DILocation(line: 215, column: 45, scope: !5) +!139 = !DILocation(line: 216, column: 62, scope: !5) +!140 = !DILocation(line: 216, column: 43, scope: !5) +!141 = !DILocation(line: 218, column: 33, scope: !5) +!142 = !DILocation(line: 390, column: 37, scope: !52, inlinedAt: !143) +!143 = !DILocation(line: 226, column: 16, scope: !5) +!144 = !DILocation(line: 390, column: 18, scope: !52, inlinedAt: !143) +!145 = !DILocation(line: 390, column: 49, scope: !52, inlinedAt: !143) +!146 = !DILocation(line: 391, column: 18, scope: !52, inlinedAt: !143) +!147 = !DILocation(line: 391, column: 49, scope: !52, inlinedAt: !143) +!148 = !DILocation(line: 395, column: 43, scope: !52, inlinedAt: !143) +!149 = !DILocation(line: 397, column: 28, scope: !52, inlinedAt: !143) +!150 = !DILocation(line: 835, column: 23, scope: !52, 
inlinedAt: !143) +!151 = !DILocation(line: 414, column: 19, scope: !52, inlinedAt: !143) +!152 = !DILocation(line: 415, column: 19, scope: !52, inlinedAt: !143) +!153 = !DILocation(line: 459, column: 19, scope: !52, inlinedAt: !143) +!154 = !DILocation(line: 395, column: 63, scope: !52, inlinedAt: !143) +!155 = !DILocation(line: 531, column: 19, scope: !52, inlinedAt: !143) +!156 = !DILocation(line: 461, column: 14, scope: !52, inlinedAt: !143) +!157 = !DILocation(line: 524, column: 27, scope: !52, inlinedAt: !143) +!158 = !DILocation(line: 525, column: 39, scope: !52, inlinedAt: !143) +!159 = !DILocation(line: 525, column: 21, scope: !52, inlinedAt: !143) +!160 = !DILocation(line: 530, column: 20, scope: !52, inlinedAt: !143) +!161 = !DILocation(line: 531, column: 14, scope: !52, inlinedAt: !143) +!162 = !DILocation(line: 551, column: 15, scope: !52, inlinedAt: !143) +!163 = !DILocation(line: 553, column: 21, scope: !52, inlinedAt: !143) +!164 = !DILocation(line: 788, column: 33, scope: !52, inlinedAt: !143) +!165 = !DILocation(line: 789, column: 38, scope: !52, inlinedAt: !143) +!166 = !DILocation(line: 789, column: 24, scope: !52, inlinedAt: !143) +!167 = !DILocation(line: 790, column: 109, scope: !52, inlinedAt: !143) +!168 = !DILocation(line: 790, column: 113, scope: !52, inlinedAt: !143) +!169 = !DILocation(line: 790, column: 55, scope: !52, inlinedAt: !143) +!170 = !DILocation(line: 790, column: 25, scope: !52, inlinedAt: !143) +!171 = !DILocation(line: 791, column: 35, scope: !52, inlinedAt: !143) +!172 = !DILocation(line: 792, column: 34, scope: !52, inlinedAt: !143) +!173 = !DILocation(line: 792, column: 48, scope: !52, inlinedAt: !143) +!174 = !DILocation(line: 792, column: 63, scope: !52, inlinedAt: !143) +!175 = !DILocation(line: 414, column: 28, scope: !52, inlinedAt: !143) +!176 = !DILocation(line: 793, column: 29, scope: !52, inlinedAt: !143) +!177 = !DILocation(line: 231, column: 24, scope: !5) +!178 = !DILocation(line: 231, column: 56, scope: !5) 
+!179 = !DILocation(line: 232, column: 14, scope: !5) +!180 = !DILocation(line: 234, column: 30, scope: !5) +!181 = !DILocation(line: 252, column: 25, scope: !5) +!182 = !DILocation(line: 253, column: 29, scope: !5) +!183 = !DILocation(line: 825, column: 38, scope: !52, inlinedAt: !184) +!184 = !DILocation(line: 256, column: 107, scope: !5) +!185 = !DILocation(line: 825, column: 20, scope: !52, inlinedAt: !184) +!186 = !DILocation(line: 825, column: 56, scope: !52, inlinedAt: !184) +!187 = !DILocation(line: 825, column: 49, scope: !52, inlinedAt: !184) +!188 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !184) +!189 = !DILocation(line: 825, column: 20, scope: !52, inlinedAt: !190) +!190 = !DILocation(line: 257, column: 107, scope: !5) +!191 = !DILocation(line: 825, column: 49, scope: !52, inlinedAt: !190) +!192 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !190) +!193 = !DILocation(line: 263, column: 32, scope: !5) +!194 = !DILocation(line: 266, column: 56, scope: !5) +!195 = !DILocation(line: 281, column: 58, scope: !5) +!196 = !DILocation(line: 281, column: 80, scope: !5) +!197 = !DILocation(line: 282, column: 53, scope: !5) +!198 = !DILocation(line: 282, column: 81, scope: !5) +!199 = !DILocation(line: 282, column: 70, scope: !5) +!200 = !DILocation(line: 286, column: 32, scope: !5) +!201 = !DILocation(line: 287, column: 30, scope: !5) +!202 = !DILocation(line: 287, column: 43, scope: !5) +!203 = !DILocation(line: 288, column: 55, scope: !5) +!204 = !DILocation(line: 288, column: 42, scope: !5) +!205 = !DILocation(line: 290, column: 45, scope: !5) +!206 = !DILocation(line: 290, column: 32, scope: !5) +!207 = !DILocation(line: 601, column: 37, scope: !52, inlinedAt: !208) +!208 = !DILocation(line: 298, column: 16, scope: !5) +!209 = !DILocation(line: 602, column: 38, scope: !52, inlinedAt: !208) +!210 = !DILocation(line: 608, column: 42, scope: !52, inlinedAt: !208) +!211 = !DILocation(line: 608, column: 61, scope: !52, inlinedAt: !208) 
+!212 = !DILocation(line: 701, column: 35, scope: !52, inlinedAt: !208) +!213 = !DILocation(line: 610, column: 28, scope: !52, inlinedAt: !208) +!214 = !DILocation(line: 701, column: 24, scope: !52, inlinedAt: !208) +!215 = !DILocation(line: 306, column: 41, scope: !5) +!216 = !DILocation(line: 307, column: 34, scope: !5) +!217 = !DILocation(line: 307, column: 47, scope: !5) +!218 = !DILocation(line: 308, column: 64, scope: !5) +!219 = !DILocation(line: 308, column: 46, scope: !5) +!220 = !DILocation(line: 310, column: 36, scope: !5) +!221 = !DILocation(line: 601, column: 37, scope: !52, inlinedAt: !222) +!222 = !DILocation(line: 318, column: 20, scope: !5) +!223 = !DILocation(line: 602, column: 38, scope: !52, inlinedAt: !222) +!224 = !DILocation(line: 608, column: 42, scope: !52, inlinedAt: !222) +!225 = !DILocation(line: 608, column: 61, scope: !52, inlinedAt: !222) +!226 = !DILocation(line: 676, column: 20, scope: !52, inlinedAt: !208) +!227 = !DILocation(line: 262, column: 30, scope: !5) +!228 = !DILocation(line: 263, column: 51, scope: !5) +!229 = !DILocation(line: 266, column: 44, scope: !5) +!230 = !DILocation(line: 266, column: 67, scope: !5) +!231 = !DILocation(line: 267, column: 36, scope: !5) +!232 = !DILocation(line: 267, column: 46, scope: !5) +!233 = !DILocation(line: 267, column: 70, scope: !5) +!234 = !DILocation(line: 269, column: 50, scope: !5) +!235 = !DILocation(line: 269, column: 60, scope: !5) +!236 = !DILocation(line: 271, column: 21, scope: !5) +!237 = !DILocation(line: 272, column: 23, scope: !5) +!238 = !DILocation(line: 275, column: 25, scope: !5) +!239 = !DILocation(line: 276, column: 29, scope: !5) +!240 = !DILocation(line: 601, column: 18, scope: !52, inlinedAt: !208) +!241 = !DILocation(line: 601, column: 49, scope: !52, inlinedAt: !208) +!242 = !DILocation(line: 602, column: 19, scope: !52, inlinedAt: !208) +!243 = !DILocation(line: 602, column: 51, scope: !52, inlinedAt: !208) +!244 = !DILocation(line: 835, column: 23, scope: !52, 
inlinedAt: !208) +!245 = !DILocation(line: 672, column: 28, scope: !52, inlinedAt: !208) +!246 = !DILocation(line: 672, column: 22, scope: !52, inlinedAt: !208) +!247 = !DILocation(line: 746, column: 29, scope: !52, inlinedAt: !208) +!248 = !DILocation(line: 746, column: 21, scope: !52, inlinedAt: !208) +!249 = !DILocation(line: 626, column: 19, scope: !52, inlinedAt: !208) +!250 = !DILocation(line: 627, column: 19, scope: !52, inlinedAt: !208) +!251 = !DILocation(line: 675, column: 26, scope: !52, inlinedAt: !208) +!252 = !DILocation(line: 675, column: 46, scope: !52, inlinedAt: !208) +!253 = !DILocation(line: 678, column: 15, scope: !52, inlinedAt: !208) +!254 = !DILocation(line: 698, column: 25, scope: !52, inlinedAt: !208) +!255 = !DILocation(line: 704, column: 24, scope: !52, inlinedAt: !208) +!256 = !DILocation(line: 706, column: 24, scope: !52, inlinedAt: !208) +!257 = !DILocation(line: 739, column: 27, scope: !52, inlinedAt: !208) +!258 = !DILocation(line: 736, column: 69, scope: !52, inlinedAt: !208) +!259 = !DILocation(line: 740, column: 40, scope: !52, inlinedAt: !208) +!260 = !DILocation(line: 740, column: 22, scope: !52, inlinedAt: !208) +!261 = !DILocation(line: 744, column: 24, scope: !52, inlinedAt: !208) +!262 = !DILocation(line: 744, column: 43, scope: !52, inlinedAt: !208) +!263 = !DILocation(line: 750, column: 20, scope: !52, inlinedAt: !208) +!264 = !DILocation(line: 751, column: 22, scope: !52, inlinedAt: !208) +!265 = !DILocation(line: 751, column: 16, scope: !52, inlinedAt: !208) +!266 = !DILocation(line: 775, column: 24, scope: !52, inlinedAt: !208) +!267 = !DILocation(line: 773, column: 45, scope: !52, inlinedAt: !208) +!268 = !DILocation(line: 775, column: 43, scope: !52, inlinedAt: !208) +!269 = !DILocation(line: 628, column: 19, scope: !52, inlinedAt: !208) +!270 = !DILocation(line: 788, column: 33, scope: !52, inlinedAt: !208) +!271 = !DILocation(line: 789, column: 38, scope: !52, inlinedAt: !208) +!272 = !DILocation(line: 789, column: 
24, scope: !52, inlinedAt: !208) +!273 = !DILocation(line: 790, column: 109, scope: !52, inlinedAt: !208) +!274 = !DILocation(line: 790, column: 113, scope: !52, inlinedAt: !208) +!275 = !DILocation(line: 790, column: 55, scope: !52, inlinedAt: !208) +!276 = !DILocation(line: 790, column: 25, scope: !52, inlinedAt: !208) +!277 = !DILocation(line: 791, column: 35, scope: !52, inlinedAt: !208) +!278 = !DILocation(line: 792, column: 34, scope: !52, inlinedAt: !208) +!279 = !DILocation(line: 792, column: 48, scope: !52, inlinedAt: !208) +!280 = !DILocation(line: 792, column: 63, scope: !52, inlinedAt: !208) +!281 = !DILocation(line: 793, column: 29, scope: !52, inlinedAt: !208) +!282 = !DILocation(line: 793, column: 61, scope: !52, inlinedAt: !208) +!283 = !DILocation(line: 793, column: 42, scope: !52, inlinedAt: !208) +!284 = !DILocation(line: 626, column: 28, scope: !52, inlinedAt: !208) +!285 = !DILocation(line: 627, column: 28, scope: !52, inlinedAt: !208) +!286 = !DILocation(line: 601, column: 18, scope: !52, inlinedAt: !222) +!287 = !DILocation(line: 601, column: 49, scope: !52, inlinedAt: !222) +!288 = !DILocation(line: 602, column: 19, scope: !52, inlinedAt: !222) +!289 = !DILocation(line: 602, column: 51, scope: !52, inlinedAt: !222) +!290 = !DILocation(line: 835, column: 23, scope: !52, inlinedAt: !222) +!291 = !DILocation(line: 672, column: 28, scope: !52, inlinedAt: !222) +!292 = !DILocation(line: 672, column: 22, scope: !52, inlinedAt: !222) +!293 = !DILocation(line: 746, column: 29, scope: !52, inlinedAt: !222) +!294 = !DILocation(line: 746, column: 21, scope: !52, inlinedAt: !222) +!295 = !DILocation(line: 626, column: 19, scope: !52, inlinedAt: !222) +!296 = !DILocation(line: 627, column: 19, scope: !52, inlinedAt: !222) +!297 = !DILocation(line: 610, column: 28, scope: !52, inlinedAt: !222) +!298 = !DILocation(line: 675, column: 26, scope: !52, inlinedAt: !222) +!299 = !DILocation(line: 675, column: 46, scope: !52, inlinedAt: !222) +!300 = 
!DILocation(line: 676, column: 20, scope: !52, inlinedAt: !222) +!301 = !DILocation(line: 678, column: 15, scope: !52, inlinedAt: !222) +!302 = !DILocation(line: 739, column: 27, scope: !52, inlinedAt: !222) +!303 = !DILocation(line: 740, column: 40, scope: !52, inlinedAt: !222) +!304 = !DILocation(line: 740, column: 22, scope: !52, inlinedAt: !222) +!305 = !DILocation(line: 744, column: 24, scope: !52, inlinedAt: !222) +!306 = !DILocation(line: 744, column: 43, scope: !52, inlinedAt: !222) +!307 = !DILocation(line: 750, column: 20, scope: !52, inlinedAt: !222) +!308 = !DILocation(line: 751, column: 22, scope: !52, inlinedAt: !222) +!309 = !DILocation(line: 751, column: 16, scope: !52, inlinedAt: !222) +!310 = !DILocation(line: 775, column: 24, scope: !52, inlinedAt: !222) +!311 = !DILocation(line: 775, column: 43, scope: !52, inlinedAt: !222) +!312 = !DILocation(line: 788, column: 33, scope: !52, inlinedAt: !222) +!313 = !DILocation(line: 789, column: 38, scope: !52, inlinedAt: !222) +!314 = !DILocation(line: 789, column: 24, scope: !52, inlinedAt: !222) +!315 = !DILocation(line: 790, column: 109, scope: !52, inlinedAt: !222) +!316 = !DILocation(line: 790, column: 113, scope: !52, inlinedAt: !222) +!317 = !DILocation(line: 790, column: 55, scope: !52, inlinedAt: !222) +!318 = !DILocation(line: 790, column: 25, scope: !52, inlinedAt: !222) +!319 = !DILocation(line: 791, column: 35, scope: !52, inlinedAt: !222) +!320 = !DILocation(line: 792, column: 34, scope: !52, inlinedAt: !222) +!321 = !DILocation(line: 792, column: 48, scope: !52, inlinedAt: !222) +!322 = !DILocation(line: 792, column: 63, scope: !52, inlinedAt: !222) +!323 = !DILocation(line: 793, column: 29, scope: !52, inlinedAt: !222) +!324 = !DILocation(line: 793, column: 61, scope: !52, inlinedAt: !222) +!325 = !DILocation(line: 793, column: 42, scope: !52, inlinedAt: !222) +!326 = !DILocation(line: 626, column: 28, scope: !52, inlinedAt: !222) +!327 = !DILocation(line: 627, column: 28, scope: !52, 
inlinedAt: !222) +!328 = !DILocation(line: 628, column: 19, scope: !52, inlinedAt: !222) +!329 = !DILocation(line: 323, column: 23, scope: !5) +!330 = !DILocation(line: 323, column: 55, scope: !5) +!331 = !DILocation(line: 330, column: 30, scope: !5) +!332 = !DILocation(line: 334, column: 14, scope: !5) +!333 = !DILocation(line: 344, column: 58, scope: !5) +!334 = !DILocation(line: 345, column: 29, scope: !5) +!335 = !DILocation(line: 345, column: 69, scope: !5) +!336 = !DILocation(line: 139, column: 4, scope: !5) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.source b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.source new file mode 100644 index 0000000000000000000000000000000000000000..ec5f4ac3e3e9065d8f48425b2c076c8cdcd75226 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.source @@ -0,0 +1,2072 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":18:0) +#loc201 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":32:0) +#loc211 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":812:0) +#loc221 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":348:0) +#loc253 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":423:0) +#loc318 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":797:0) +#loc321 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":781:0) +#loc342 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":559:0) +#loc374 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":634:0) +#loc445 = loc("arg_Q"(#loc)) +#loc446 = loc("arg_K"(#loc)) +#loc447 = loc("arg_V"(#loc)) +#loc448 = loc("arg_LSE"(#loc)) +#loc449 = loc("arg_DELTA"(#loc)) +#loc450 = loc("arg_DO"(#loc)) +#loc451 = loc("arg_DQ"(#loc)) +#loc452 = loc("arg_DV"(#loc)) +#loc453 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc454 = loc("arg_KV_IDX"(#loc)) +#loc455 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc456 = loc("arg_Q_IDX"(#loc)) +#loc457 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc458 = loc("arg_FULL_KV_IDX"(#loc)) +#loc459 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc460 = loc("arg_FULL_Q_IDX"(#loc)) +#loc461 = loc("in_ptr16"(#loc)) +#loc462 = loc("out_ptr0"(#loc)) +#loc647 = loc("x"(#loc201)) +#loc648 = loc("ptr"(#loc211)) +#loc649 = loc("offs_m"(#loc211)) +#loc650 = loc("offs_n"(#loc211)) +#loc651 = loc("stride_m"(#loc211)) +#loc652 = loc("stride_n"(#loc211)) +#loc653 = loc("M_LEN"(#loc211)) +#loc660 = loc("arg_Q"(#loc221)) +#loc661 = loc("arg_K"(#loc221)) +#loc662 = loc("arg_V"(#loc221)) +#loc663 = loc("arg_LSE"(#loc221)) +#loc664 = loc("arg_DELTA"(#loc221)) +#loc665 = loc("arg_DO"(#loc221)) +#loc666 = loc("arg_DQ"(#loc221)) +#loc667 = loc("arg_DV"(#loc221)) +#loc668 = loc("arg_KV_NUM_BLKS"(#loc221)) +#loc669 = loc("arg_KV_IDX"(#loc221)) +#loc670 = loc("arg_Q_NUM_BLKS"(#loc221)) +#loc671 = loc("arg_Q_IDX"(#loc221)) +#loc672 = loc("arg_FULL_KV_NUM_BLKS"(#loc221)) +#loc673 = loc("arg_FULL_KV_IDX"(#loc221)) +#loc674 = loc("arg_FULL_Q_NUM_BLKS"(#loc221)) +#loc675 = loc("arg_FULL_Q_IDX"(#loc221)) +#loc676 = loc("in_ptr16"(#loc221)) +#loc677 = loc("out_ptr0"(#loc221)) +#loc678 = loc("K"(#loc221)) +#loc679 = 
loc("V"(#loc221)) +#loc680 = loc("dq"(#loc221)) +#loc681 = loc("q"(#loc221)) +#loc682 = loc("do"(#loc221)) +#loc683 = loc("Di"(#loc221)) +#loc684 = loc("lse"(#loc221)) +#loc685 = loc("off_z"(#loc221)) +#loc686 = loc("off_hq"(#loc221)) +#loc687 = loc("offs_m2"(#loc221)) +#loc688 = loc("offs_n2"(#loc221)) +#loc689 = loc("stride_kn"(#loc221)) +#loc690 = loc("stride_kd"(#loc221)) +#loc691 = loc("stride_vn"(#loc221)) +#loc692 = loc("stride_vd"(#loc221)) +#loc693 = loc("kv_indices"(#loc221)) +#loc694 = loc("sparse_kv_num_blocks"(#loc221)) +#loc723 = loc("arg_Q"(#loc253)) +#loc724 = loc("arg_K"(#loc253)) +#loc725 = loc("arg_V"(#loc253)) +#loc726 = loc("arg_LSE"(#loc253)) +#loc727 = loc("arg_DELTA"(#loc253)) +#loc728 = loc("arg_DO"(#loc253)) +#loc729 = loc("arg_DQ"(#loc253)) +#loc730 = loc("arg_DV"(#loc253)) +#loc731 = loc("arg_KV_NUM_BLKS"(#loc253)) +#loc732 = loc("arg_KV_IDX"(#loc253)) +#loc733 = loc("arg_Q_NUM_BLKS"(#loc253)) +#loc734 = loc("arg_Q_IDX"(#loc253)) +#loc735 = loc("arg_FULL_KV_NUM_BLKS"(#loc253)) +#loc736 = loc("arg_FULL_KV_IDX"(#loc253)) +#loc737 = loc("arg_FULL_Q_NUM_BLKS"(#loc253)) +#loc738 = loc("arg_FULL_Q_IDX"(#loc253)) +#loc739 = loc("in_ptr16"(#loc253)) +#loc740 = loc("out_ptr0"(#loc253)) +#loc741 = loc("dq"(#loc253)) +#loc742 = loc("q"(#loc253)) +#loc743 = loc("kT_ptrs"(#loc253)) +#loc744 = loc("vT_ptrs"(#loc253)) +#loc745 = loc("do"(#loc253)) +#loc746 = loc("Di"(#loc253)) +#loc747 = loc("lse"(#loc253)) +#loc748 = loc("Q_LEN"(#loc253)) +#loc749 = loc("KV_LEN"(#loc253)) +#loc750 = loc("off_z"(#loc253)) +#loc751 = loc("off_hq"(#loc253)) +#loc752 = loc("offs_m2"(#loc253)) +#loc753 = loc("offs_n2"(#loc253)) +#loc754 = loc("offs_k"(#loc253)) +#loc755 = loc("offs_v"(#loc253)) +#loc756 = loc("stride_kn"(#loc253)) +#loc757 = loc("stride_kd"(#loc253)) +#loc758 = loc("stride_vn"(#loc253)) +#loc759 = loc("stride_vd"(#loc253)) +#loc760 = loc("kv_indices"(#loc253)) +#loc761 = loc("sparse_kv_num_blocks"(#loc253)) +#loc824 = loc("N_LEN"(#loc211)) +#loc825 = 
loc("indices"(#loc318)) +#loc826 = loc("loop_iter"(#loc321)) +#loc827 = loc("col_indices"(#loc321)) +#loc828 = loc("total_blocks"(#loc321)) +#loc847 = loc("arg_Q"(#loc342)) +#loc848 = loc("arg_K"(#loc342)) +#loc849 = loc("arg_V"(#loc342)) +#loc850 = loc("arg_LSE"(#loc342)) +#loc851 = loc("arg_DELTA"(#loc342)) +#loc852 = loc("arg_DO"(#loc342)) +#loc853 = loc("arg_DQ"(#loc342)) +#loc854 = loc("arg_DV"(#loc342)) +#loc855 = loc("arg_KV_NUM_BLKS"(#loc342)) +#loc856 = loc("arg_KV_IDX"(#loc342)) +#loc857 = loc("arg_Q_NUM_BLKS"(#loc342)) +#loc858 = loc("arg_Q_IDX"(#loc342)) +#loc859 = loc("arg_FULL_KV_NUM_BLKS"(#loc342)) +#loc860 = loc("arg_FULL_KV_IDX"(#loc342)) +#loc861 = loc("arg_FULL_Q_NUM_BLKS"(#loc342)) +#loc862 = loc("arg_FULL_Q_IDX"(#loc342)) +#loc863 = loc("in_ptr16"(#loc342)) +#loc864 = loc("out_ptr0"(#loc342)) +#loc865 = loc("Q"(#loc342)) +#loc866 = loc("DO"(#loc342)) +#loc867 = loc("DELTA"(#loc342)) +#loc868 = loc("LSE"(#loc342)) +#loc869 = loc("dk"(#loc342)) +#loc870 = loc("dv"(#loc342)) +#loc871 = loc("k"(#loc342)) +#loc872 = loc("v"(#loc342)) +#loc873 = loc("off_z"(#loc342)) +#loc874 = loc("off_hq"(#loc342)) +#loc875 = loc("offs_n1"(#loc342)) +#loc876 = loc("offs_m1"(#loc342)) +#loc877 = loc("stride_qm"(#loc342)) +#loc878 = loc("stride_qd"(#loc342)) +#loc879 = loc("stride_dom"(#loc342)) +#loc880 = loc("stride_dod"(#loc342)) +#loc881 = loc("q_indices"(#loc342)) +#loc882 = loc("sparse_q_num_blocks"(#loc342)) +#loc910 = loc("arg_Q"(#loc374)) +#loc911 = loc("arg_K"(#loc374)) +#loc912 = loc("arg_V"(#loc374)) +#loc913 = loc("arg_LSE"(#loc374)) +#loc914 = loc("arg_DELTA"(#loc374)) +#loc915 = loc("arg_DO"(#loc374)) +#loc916 = loc("arg_DQ"(#loc374)) +#loc917 = loc("arg_DV"(#loc374)) +#loc918 = loc("arg_KV_NUM_BLKS"(#loc374)) +#loc919 = loc("arg_KV_IDX"(#loc374)) +#loc920 = loc("arg_Q_NUM_BLKS"(#loc374)) +#loc921 = loc("arg_Q_IDX"(#loc374)) +#loc922 = loc("arg_FULL_KV_NUM_BLKS"(#loc374)) +#loc923 = loc("arg_FULL_KV_IDX"(#loc374)) +#loc924 = 
loc("arg_FULL_Q_NUM_BLKS"(#loc374)) +#loc925 = loc("arg_FULL_Q_IDX"(#loc374)) +#loc926 = loc("in_ptr16"(#loc374)) +#loc927 = loc("out_ptr0"(#loc374)) +#loc928 = loc("dk"(#loc374)) +#loc929 = loc("dv"(#loc374)) +#loc930 = loc("qT_ptrs"(#loc374)) +#loc931 = loc("k"(#loc374)) +#loc932 = loc("v"(#loc374)) +#loc933 = loc("do_ptrs"(#loc374)) +#loc934 = loc("DELTA"(#loc374)) +#loc935 = loc("LSE"(#loc374)) +#loc936 = loc("Q_LEN"(#loc374)) +#loc937 = loc("KV_LEN"(#loc374)) +#loc938 = loc("off_z"(#loc374)) +#loc939 = loc("off_hq"(#loc374)) +#loc940 = loc("offs_n1"(#loc374)) +#loc941 = loc("offs_m1"(#loc374)) +#loc942 = loc("offs_k"(#loc374)) +#loc943 = loc("offs_v"(#loc374)) +#loc944 = loc("stride_qm"(#loc374)) +#loc945 = loc("stride_qd"(#loc374)) +#loc946 = loc("stride_dom"(#loc374)) +#loc947 = loc("stride_dod"(#loc374)) +#loc948 = loc("q_indices"(#loc374)) +#loc949 = loc("sparse_q_num_blocks"(#loc374)) +module { + tt.func public @triton_tem_fused_zeros_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), 
%arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %in_ptr16: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr16"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc))) attributes {noinline = false} { + %c8388608_i32 = arith.constant 8388608 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c2097152_i32 = arith.constant 2097152 : i32 loc(#loc2) + %c262144_i32 = arith.constant 262144 : i32 loc(#loc2) + %c128_i32_0 = arith.constant 128 : i32 loc(#loc2) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc2) + %c2097152_i32_2 = arith.constant 2097152 : i32 loc(#loc3) + %c262144_i32_3 = arith.constant 262144 : i32 loc(#loc3) + %c128_i32_4 = arith.constant 128 : i32 loc(#loc3) + %c1_i32_5 = arith.constant 1 : i32 loc(#loc3) + %c8388608_i32_6 = arith.constant 8388608 : i32 loc(#loc4) + %c262144_i32_7 = arith.constant 262144 : i32 loc(#loc4) + %c128_i32_8 = arith.constant 128 : i32 loc(#loc4) + %c1_i32_9 = arith.constant 1 : i32 loc(#loc4) + %c8388608_i32_10 = arith.constant 8388608 : i32 loc(#loc5) + %c128_i32_11 = arith.constant 128 : i32 loc(#loc5) + %c4096_i32_12 = arith.constant 4096 : i32 loc(#loc5) + %c1_i32_13 = arith.constant 1 : i32 loc(#loc5) + %c2097152_i32_14 = arith.constant 2097152 : i32 loc(#loc6) + %c262144_i32_15 = arith.constant 262144 : i32 loc(#loc6) + %c128_i32_16 = arith.constant 128 : i32 loc(#loc6) + %c1_i32_17 = arith.constant 1 : i32 loc(#loc6) + %ZQ = arith.constant 2 : i32 loc(#loc463) + %HQ = arith.constant 32 : i32 loc(#loc464) + %HKV = arith.constant 8 : i32 loc(#loc465) + %Q_LEN = arith.constant 2048 : i32 loc(#loc466) + %ZKV = arith.constant 2 : i32 loc(#loc467) + %KV_LEN = arith.constant 2048 : i32 loc(#loc468) + %pid = tt.get_program_id x : i32 loc(#loc469) + %NUM_KV_BLOCKS = 
tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%KV_LEN) : (i32) -> i32 loc(#loc470) + %NUM_Q_BLOCKS = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%Q_LEN) : (i32) -> i32 loc(#loc471) + %off_zq = tt.get_program_id y : i32 loc(#loc472) + %off_hkv = tt.get_program_id z : i32 loc(#loc473) + %off_zkv = arith.remsi %off_zq, %ZKV : i32 loc(#loc474) + %SPARSE_Z = arith.constant 2 : i32 loc(#loc475) + %SPARSE_HQ = arith.constant 1 : i32 loc(#loc476) + %sparse_idx_z = arith.remsi %off_zq, %SPARSE_Z : i32 loc(#loc477) + %k_adj = arith.muli %c262144_i32, %off_hkv : i32 loc(#loc478) + %k_adj_18 = arith.muli %c2097152_i32, %off_zkv : i32 loc(#loc479) + %k_adj_19 = arith.addi %k_adj, %k_adj_18 : i32 loc(#loc480) + %k_adj_20 = arith.extsi %k_adj_19 : i32 to i64 loc(#loc481) + %v_adj = arith.muli %c262144_i32_3, %off_hkv : i32 loc(#loc482) + %v_adj_21 = arith.muli %c2097152_i32_2, %off_zkv : i32 loc(#loc483) + %v_adj_22 = arith.addi %v_adj, %v_adj_21 : i32 loc(#loc484) + %v_adj_23 = arith.extsi %v_adj_22 : i32 to i64 loc(#loc485) + %dv_adj = arith.muli %c262144_i32_15, %off_hkv : i32 loc(#loc486) + %dv_adj_24 = arith.muli %c2097152_i32_14, %off_zq : i32 loc(#loc487) + %dv_adj_25 = arith.addi %dv_adj, %dv_adj_24 : i32 loc(#loc488) + %dv_adj_26 = arith.extsi %dv_adj_25 : i32 to i64 loc(#loc489) + %K = tt.addptr %arg_K, %k_adj_20 : !tt.ptr, i64 loc(#loc490) + %V = tt.addptr %arg_V, %v_adj_23 : !tt.ptr, i64 loc(#loc491) + %DV = tt.addptr %arg_DV, %dv_adj_26 : !tt.ptr, i64 loc(#loc492) + %RCP_LN2 = arith.constant 1.44269502 : f32 loc(#loc493) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc494) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc495) + %0 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS : i32 loc(#loc40) + %1:2 = scf.if %0 -> (i32, i32) { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS : i32 loc(#loc496) + %SPARSE_Q_MULTIPLE = arith.constant 1 : i32 loc(#loc1018) + 
%SPARSE_KV_MULTIPLE = arith.constant 2 : i32 loc(#loc1019) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc499) + %off_hq2_27 = arith.constant 4 : i32 loc(#loc500) + %off_hq2_28 = arith.constant 4 : i32 loc(#loc500) + %off_hq2_29 = arith.muli %off_hkv, %off_hq2_28 : i32 loc(#loc500) + %off_hq2_30 = arith.addi %off_hq2, %off_hq2_29 : i32 loc(#loc501) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc502) + %off_pid_mask = arith.divsi %start_m2_block, %SPARSE_Q_MULTIPLE : i32 loc(#loc503) + %stride_kv_num_blks_h = arith.constant 16 : i32 loc(#loc504) + %stride_kv_idx_h = arith.constant 256 : i32 loc(#loc505) + %stride_kv_idx_m = arith.constant 16 : i32 loc(#loc506) + %sparse_idx_hq2 = arith.remsi %off_hq2_30, %SPARSE_HQ : i32 loc(#loc507) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc508) + %sparse_hz_offset_31 = arith.addi %sparse_hz_offset, %sparse_idx_hq2 : i32 loc(#loc509) + %sparse_kv_num_blks_offset = arith.muli %sparse_hz_offset_31, %stride_kv_num_blks_h : i32 loc(#loc510) + %sparse_kv_num_blks_offset_32 = arith.addi %sparse_kv_num_blks_offset, %off_pid_mask : i32 loc(#loc511) + %sparse_kv_idx_offset = arith.muli %sparse_hz_offset_31, %stride_kv_idx_h : i32 loc(#loc512) + %sparse_kv_idx_offset_33 = arith.muli %off_pid_mask, %stride_kv_idx_m : i32 loc(#loc513) + %sparse_kv_idx_offset_34 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_33 : i32 loc(#loc514) + %q_adj2 = arith.muli %c128_i32, %off_hq2_30 : i32 loc(#loc515) + %q_adj2_35 = arith.muli %c8388608_i32, %off_zq : i32 loc(#loc516) + %q_adj2_36 = arith.addi %q_adj2, %q_adj2_35 : i32 loc(#loc517) + %q_adj2_37 = arith.extsi %q_adj2_36 : i32 to i64 loc(#loc518) + %do_adj2 = arith.muli %c262144_i32_7, %off_hq2_30 : i32 loc(#loc519) + %do_adj2_38 = arith.muli %c8388608_i32_6, %off_zq : i32 loc(#loc520) + %do_adj2_39 = arith.addi %do_adj2, %do_adj2_38 : i32 loc(#loc521) + %do_adj2_40 = arith.extsi %do_adj2_39 : i32 to i64 loc(#loc522) + 
%dq_adj2 = arith.muli %c128_i32_11, %off_hq2_30 : i32 loc(#loc523) + %dq_adj2_41 = arith.muli %c8388608_i32_10, %off_zq : i32 loc(#loc524) + %dq_adj2_42 = arith.addi %dq_adj2, %dq_adj2_41 : i32 loc(#loc525) + %dq_adj2_43 = arith.extsi %dq_adj2_42 : i32 to i64 loc(#loc526) + %off_chz2 = arith.muli %off_zq, %HQ : i32 loc(#loc527) + %off_chz2_44 = arith.addi %off_chz2, %off_hq2_30 : i32 loc(#loc528) + %off_chz2_45 = arith.muli %off_chz2_44, %Q_LEN : i32 loc(#loc529) + %off_chz2_46 = arith.extsi %off_chz2_45 : i32 to i64 loc(#loc530) + %Q2 = tt.addptr %arg_Q, %q_adj2_37 : !tt.ptr, i64 loc(#loc531) + %DO2 = tt.addptr %arg_DO, %do_adj2_40 : !tt.ptr, i64 loc(#loc532) + %DQ2 = tt.addptr %arg_DQ, %dq_adj2_43 : !tt.ptr, i64 loc(#loc533) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_46 : !tt.ptr, i64 loc(#loc534) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_46 : !tt.ptr, i64 loc(#loc535) + %dq = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc536) + %start_m2 = arith.constant 128 : i32 loc(#loc537) + %start_m2_47 = arith.constant 128 : i32 loc(#loc537) + %start_m2_48 = arith.muli %start_m2_block, %start_m2_47 : i32 loc(#loc537) + %offs_m2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc538) + %offs_m2_49 = tt.splat %start_m2_48 : i32 -> tensor<128xi32> loc(#loc539) + %offs_m2_50 = arith.addi %offs_m2_49, %offs_m2 : tensor<128xi32> loc(#loc539) + %q = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%Q2, %offs_m2_50, %offs_k, %c4096_i32, %c1_i32, %Q_LEN) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc540) + %do = tt.call 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%DO2, %offs_m2_50, %offs_v, %c128_i32_8, %c1_i32_9, %Q_LEN) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc541) + %Di = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc542) + %Di_51 = tt.addptr %Di, %offs_m2_50 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc542) + %Di_52 = tt.load %Di_51 : tensor<128x!tt.ptr> loc(#loc543) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc544) + %lse_53 = tt.addptr %lse, %offs_m2_50 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc544) + %lse_54 = tt.load %lse_53 : tensor<128x!tt.ptr> loc(#loc545) + %lse_55 = arith.constant 0xFF800000 : f32 loc(#loc546) + %lse_56 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc546) + %lse_57 = arith.cmpf oeq, %lse_54, %lse_56 : tensor<128xf32> loc(#loc546) + %lse_58 = arith.constant 0.000000e+00 : f32 loc(#loc547) + %lse_59 = arith.constant 0.000000e+00 : f32 loc(#loc547) + %lse_60 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc547) + %lse_61 = arith.select %lse_57, %lse_60, %lse_54 : tensor<128xi1>, tensor<128xf32> loc(#loc547) + %lse_62 = tt.expand_dims %lse_61 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc548) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_34 : !tt.ptr, i32 loc(#loc549) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc550) + %kv_start_63 = arith.constant 128 : i32 loc(#loc551) + %kv_start_64 = arith.constant 128 : i32 loc(#loc551) + %kv_start_65 = arith.muli %kv_start, %kv_start_64 : i32 loc(#loc551) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_32 : !tt.ptr, i32 loc(#loc552) + %sparse_kv_num_blocks_66 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc553) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 
: i32} : tensor<64xi32> loc(#loc554) + %offs_n2_67 = tt.splat %kv_start_65 : i32 -> tensor<64xi32> loc(#loc555) + %offs_n2_68 = arith.addi %offs_n2_67, %offs_n2 : tensor<64xi32> loc(#loc555) + %dq_69 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(35,)cconstexpr_bf16__(36,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %K, %V, %dq, %q, %do, %Di_52, %lse_62, %off_zq, %off_hq2_30, %offs_m2_50, %offs_n2_68, %c128_i32_0, %c1_i32_1, %c128_i32_4, %c1_i32_5, %kv_indices, %sparse_kv_num_blocks_66) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc556) + %kv_indices_70 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_34 : !tt.ptr, i32 loc(#loc557) + %kv_start_71 = tt.load %kv_indices_70 : !tt.ptr loc(#loc558) + %kv_start_72 = arith.constant 128 : i32 loc(#loc559) + %kv_start_73 = arith.constant 128 : i32 loc(#loc559) + %kv_start_74 = arith.muli %kv_start_71, %kv_start_73 : i32 loc(#loc559) + %sparse_kv_num_blocks_75 = tt.addptr %arg_FULL_KV_NUM_BLKS, %sparse_kv_num_blks_offset_32 : !tt.ptr, i32 loc(#loc560) + %sparse_kv_num_blocks_76 = tt.load %sparse_kv_num_blocks_75 : !tt.ptr loc(#loc561) + %offs_n2_77 = tt.make_range {end = 64 : i32, start 
= 0 : i32} : tensor<64xi32> loc(#loc562) + %offs_n2_78 = tt.splat %kv_start_74 : i32 -> tensor<64xi32> loc(#loc563) + %offs_n2_79 = arith.addi %offs_n2_78, %offs_n2_77 : tensor<64xi32> loc(#loc563) + %dq_80 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(35,)cconstexpr_bf16__(36,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %K, %V, %dq_69, %q, %do, %Di_52, %lse_62, %off_zq, %off_hq2_30, %offs_m2_50, %offs_n2_79, %c128_i32_0, %c1_i32_1, %c128_i32_4, %c1_i32_5, %kv_indices_70, %sparse_kv_num_blocks_76) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc564) + %dq_ptrs = tt.expand_dims %offs_m2_50 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc565) + %dq_ptrs_81 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc566) + %dq_ptrs_82 = arith.muli %dq_ptrs, %dq_ptrs_81 : tensor<128x1xi32> loc(#loc566) + %dq_ptrs_83 = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc567) + %dq_ptrs_84 = tt.addptr %dq_ptrs_83, %dq_ptrs_82 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc567) + %dq_ptrs_85 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc568) + %dq_ptrs_86 = arith.constant 
dense<1> : tensor<1x128xi32> loc(#loc569) + %dq_ptrs_87 = arith.muli %dq_ptrs_85, %dq_ptrs_86 : tensor<1x128xi32> loc(#loc569) + %dq_ptrs_88 = tt.broadcast %dq_ptrs_84 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc570) + %dq_ptrs_89 = tt.broadcast %dq_ptrs_87 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc570) + %dq_ptrs_90 = tt.addptr %dq_ptrs_88, %dq_ptrs_89 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc570) + %dq_91 = arith.constant 0.0883883461 : f32 loc(#loc571) + %dq_92 = arith.constant 0.0883883461 : f32 loc(#loc571) + %dq_93 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc571) + %dq_94 = arith.mulf %dq_80, %dq_93 : tensor<128x128xf32> loc(#loc571) + %2 = arith.truncf %dq_94 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc118) + tt.store %dq_ptrs_90, %2 : tensor<128x128x!tt.ptr> loc(#loc118) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc118) + } else { + %SPARSE_Q_MULTIPLE = arith.constant 2 : i32 loc(#loc1020) + %SPARSE_KV_MULTIPLE = arith.constant 1 : i32 loc(#loc1021) + %pid_mask = arith.divsi %pid, %SPARSE_KV_MULTIPLE : i32 loc(#loc574) + %stride_q_num_blks_h = arith.constant 16 : i32 loc(#loc575) + %stride_q_idx_h = arith.constant 256 : i32 loc(#loc576) + %stride_q_idx_n = arith.constant 16 : i32 loc(#loc577) + %dv = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc578) + %dk = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc579) + %start_n1 = arith.constant 128 : i32 loc(#loc580) + %start_n1_27 = arith.constant 128 : i32 loc(#loc580) + %start_n1_28 = arith.muli %pid, %start_n1_27 : i32 loc(#loc580) + %offs_n1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc581) + %offs_n1_29 = tt.splat %start_n1_28 : i32 -> tensor<128xi32> loc(#loc582) + %offs_n1_30 = 
arith.addi %offs_n1_29, %offs_n1 : tensor<128xi32> loc(#loc582) + %k = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%K, %offs_n1_30, %offs_k, %c128_i32_0, %c1_i32_1, %KV_LEN) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc583) + %v = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%V, %offs_n1_30, %offs_v, %c128_i32_4, %c1_i32_5, %KV_LEN) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc584) + %c0_i32 = arith.constant 0 : i32 loc(#loc132) + %c4_i32 = arith.constant 4 : i32 loc(#loc132) + %c1_i32_31 = arith.constant 1 : i32 loc(#loc132) + %2 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc132) + %3 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc132) + %4 = arith.bitcast %c1_i32_31 : i32 to i32 loc(#loc132) + %5 = ub.poison : i32 loc(#loc132) + %dk_32:2 = scf.for %off_g = %2 to %3 step %4 iter_args(%dv_64 = %dv, %dk_65 = %dk) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.constant 4 : i32 loc(#loc586) + %off_hq1_66 = arith.constant 4 : i32 loc(#loc586) + %off_hq1_67 = arith.muli %off_hkv, %off_hq1_66 : i32 loc(#loc586) + %off_hq1_68 = arith.addi %off_hq1_67, %off_g : i32 loc(#loc587) + %q_adj1 = arith.muli %c128_i32, %off_hq1_68 : i32 loc(#loc588) + %q_adj1_69 = arith.muli %c8388608_i32, %off_zq : i32 loc(#loc589) + %q_adj1_70 = arith.addi %q_adj1, %q_adj1_69 : i32 loc(#loc590) + %q_adj1_71 = arith.extsi %q_adj1_70 : i32 to i64 loc(#loc591) + %do_adj1 = arith.muli %c262144_i32_7, %off_hq1_68 : i32 loc(#loc592) + %do_adj1_72 = arith.muli %c8388608_i32_6, %off_zq : i32 loc(#loc593) + %do_adj1_73 = arith.addi %do_adj1, 
%do_adj1_72 : i32 loc(#loc594) + %do_adj1_74 = arith.extsi %do_adj1_73 : i32 to i64 loc(#loc595) + %dq_adj1 = arith.muli %c128_i32_11, %off_hq1_68 : i32 loc(#loc596) + %dq_adj1_75 = arith.muli %c8388608_i32_10, %off_zq : i32 loc(#loc597) + %dq_adj1_76 = arith.addi %dq_adj1, %dq_adj1_75 : i32 loc(#loc598) + %dq_adj1_77 = arith.extsi %dq_adj1_76 : i32 to i64 loc(#loc599) + %off_chz1 = arith.muli %off_zq, %HQ : i32 loc(#loc600) + %off_chz1_78 = arith.addi %off_chz1, %off_hq1_68 : i32 loc(#loc601) + %off_chz1_79 = arith.muli %off_chz1_78, %Q_LEN : i32 loc(#loc602) + %off_chz1_80 = arith.extsi %off_chz1_79 : i32 to i64 loc(#loc603) + %Q1 = tt.addptr %arg_Q, %q_adj1_71 : !tt.ptr, i64 loc(#loc604) + %DO1 = tt.addptr %arg_DO, %do_adj1_74 : !tt.ptr, i64 loc(#loc605) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_80 : !tt.ptr, i64 loc(#loc606) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_80 : !tt.ptr, i64 loc(#loc607) + %sparse_idx_hq1 = arith.remsi %off_hq1_68, %SPARSE_HQ : i32 loc(#loc608) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc609) + %sparse_hz_offset_81 = arith.addi %sparse_hz_offset, %sparse_idx_hq1 : i32 loc(#loc610) + %sparse_q_num_blks_offset = arith.muli %sparse_hz_offset_81, %stride_q_num_blks_h : i32 loc(#loc611) + %sparse_q_num_blks_offset_82 = arith.addi %sparse_q_num_blks_offset, %pid_mask : i32 loc(#loc612) + %sparse_q_idx_offset = arith.muli %sparse_hz_offset_81, %stride_q_idx_h : i32 loc(#loc613) + %sparse_q_idx_offset_83 = arith.muli %pid_mask, %stride_q_idx_n : i32 loc(#loc614) + %sparse_q_idx_offset_84 = arith.addi %sparse_q_idx_offset, %sparse_q_idx_offset_83 : i32 loc(#loc615) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset_84 : !tt.ptr, i32 loc(#loc616) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc617) + %q_start_85 = arith.constant 128 : i32 loc(#loc618) + %q_start_86 = arith.constant 128 : i32 loc(#loc618) + %q_start_87 = arith.muli %q_start, %q_start_86 : i32 loc(#loc618) + %sparse_q_num_blocks = tt.addptr 
%arg_Q_NUM_BLKS, %sparse_q_num_blks_offset_82 : !tt.ptr, i32 loc(#loc619) + %sparse_q_num_blocks_88 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc620) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc621) + %offs_m1_89 = tt.splat %q_start_87 : i32 -> tensor<64xi32> loc(#loc622) + %offs_m1_90 = arith.addi %offs_m1_89, %offs_m1 : tensor<64xi32> loc(#loc622) + %11:2 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %Q1, %DO1, %DELTA1, %LSE1, %dk_65, %dv_64, %k, %v, %off_zq, %off_hq1_68, %offs_n1_30, %offs_m1_90, %c4096_i32, %c1_i32, %c128_i32_8, %c1_i32_9, %q_indices, %sparse_q_num_blocks_88) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc170) + %q_indices_91 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset_84 : !tt.ptr, i32 loc(#loc623) + %q_start_92 = tt.load %q_indices_91 : !tt.ptr loc(#loc624) + %q_start_93 = arith.constant 128 : i32 loc(#loc625) + %q_start_94 = arith.constant 128 : i32 loc(#loc625) + %q_start_95 = arith.muli %q_start_92, %q_start_94 : i32 loc(#loc625) + 
%sparse_q_num_blocks_96 = tt.addptr %arg_FULL_Q_NUM_BLKS, %sparse_q_num_blks_offset_82 : !tt.ptr, i32 loc(#loc626) + %sparse_q_num_blocks_97 = tt.load %sparse_q_num_blocks_96 : !tt.ptr loc(#loc627) + %offs_m1_98 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc628) + %offs_m1_99 = tt.splat %q_start_95 : i32 -> tensor<64xi32> loc(#loc629) + %offs_m1_100 = arith.addi %offs_m1_99, %offs_m1_98 : tensor<64xi32> loc(#loc629) + %12:2 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %Q1, %DO1, %DELTA1, %LSE1, %11#0, %11#1, %k, %v, %off_zq, %off_hq1_68, %offs_n1_30, %offs_m1_100, %c4096_i32, %c1_i32, %c128_i32_8, %c1_i32_9, %q_indices_91, %sparse_q_num_blocks_97) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc178) + scf.yield %12#1, %12#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc179) + } loc(#loc1022) + %dv_ptrs = tt.expand_dims %offs_n1_30 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc630) + %dv_ptrs_33 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc631) + %dv_ptrs_34 = 
arith.muli %dv_ptrs, %dv_ptrs_33 : tensor<128x1xi32> loc(#loc631) + %dv_ptrs_35 = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc632) + %dv_ptrs_36 = tt.addptr %dv_ptrs_35, %dv_ptrs_34 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc632) + %dv_ptrs_37 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc633) + %dv_ptrs_38 = arith.constant dense<1> : tensor<1x128xi32> loc(#loc634) + %dv_ptrs_39 = arith.muli %dv_ptrs_37, %dv_ptrs_38 : tensor<1x128xi32> loc(#loc634) + %dv_ptrs_40 = tt.broadcast %dv_ptrs_36 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc635) + %dv_ptrs_41 = tt.broadcast %dv_ptrs_39 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc635) + %dv_ptrs_42 = tt.addptr %dv_ptrs_40, %dv_ptrs_41 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc635) + %index_n = tt.expand_dims %offs_n1_30 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc636) + %index_k = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc637) + %index_v = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc638) + %6 = arith.truncf %dk_32#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc189) + tt.store %dv_ptrs_42, %6 : tensor<128x128x!tt.ptr> loc(#loc189) + %dk_43 = arith.constant 0.0883883461 : f32 loc(#loc639) + %dk_44 = arith.constant 0.0883883461 : f32 loc(#loc639) + %dk_45 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc639) + %dk_46 = arith.mulf %dk_32#1, %dk_45 : tensor<128x128xf32> loc(#loc639) + %mask = arith.constant dense<2048> : tensor<128x1xi32> loc(#loc640) + %mask_47 = arith.cmpi slt, %index_n, %mask : tensor<128x1xi32> loc(#loc640) + %xindex = arith.constant 128 : i32 loc(#loc641) + %xindex_48 = arith.constant 128 : i32 loc(#loc641) + %xindex_49 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc641) + %xindex_50 = arith.muli %xindex_49, %index_n : tensor<128x1xi32> loc(#loc641) + %xindex_51 = 
tt.broadcast %index_k : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc642) + %xindex_52 = tt.broadcast %xindex_50 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc642) + %xindex_53 = arith.addi %xindex_51, %xindex_52 : tensor<128x128xi32> loc(#loc642) + %xindex_54 = arith.constant 262144 : i32 loc(#loc643) + %xindex_55 = arith.constant 262144 : i32 loc(#loc643) + %xindex_56 = arith.muli %xindex_55, %off_hkv : i32 loc(#loc643) + %xindex_57 = tt.splat %xindex_56 : i32 -> tensor<128x128xi32> loc(#loc644) + %xindex_58 = arith.addi %xindex_53, %xindex_57 : tensor<128x128xi32> loc(#loc644) + %xindex_59 = arith.constant 2097152 : i32 loc(#loc645) + %xindex_60 = arith.constant 2097152 : i32 loc(#loc645) + %xindex_61 = arith.muli %xindex_60, %off_zq : i32 loc(#loc645) + %xindex_62 = tt.splat %xindex_61 : i32 -> tensor<128x128xi32> loc(#loc646) + %xindex_63 = arith.addi %xindex_58, %xindex_62 : tensor<128x128xi32> loc(#loc646) + %7 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc198) + %8 = tt.addptr %7, %xindex_63 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc198) + %9 = tt.broadcast %mask_47 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc199) + %10 = arith.truncf %dk_46 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc199) + tt.store %8, %10, %9 : tensor<128x128x!tt.ptr> loc(#loc199) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc199) + } loc(#loc41) + tt.return loc(#loc200) + } loc(#loc) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%x: i32 loc("x"(#loc201))) -> i32 attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 loc(#loc202) + %c128_i32_0 = arith.constant 128 : i32 loc(#loc202) + %0 = arith.addi %x, %c128_i32_0 : i32 loc(#loc202) + %c1_i32 = arith.constant 1 : i32 loc(#loc203) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc203) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc203) + %c128_i32_2 = arith.constant 128 : i32 loc(#loc204) + %c128_i32_3 = 
arith.constant 128 : i32 loc(#loc204) + %2 = arith.divsi %1, %c128_i32_3 : i32 loc(#loc204) + tt.return %2 : i32 loc(#loc205) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc206) + tt.return %3 : i32 loc(#loc206) + } loc(#loc201) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() -> tensor<128x128xf32> attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 loc(#loc208) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc208) + tt.return %cst_0 : tensor<128x128xf32> loc(#loc209) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc210) + tt.return %0 : tensor<128x128xf32> loc(#loc210) + } loc(#loc207) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: !tt.ptr loc("ptr"(#loc211)), %offs_m: tensor<128xi32> loc("offs_m"(#loc211)), %offs_n: tensor<128xi32> loc("offs_n"(#loc211)), %stride_m: i32 loc("stride_m"(#loc211)), %stride_n: i32 loc("stride_n"(#loc211)), %M_LEN: i32 loc("M_LEN"(#loc211))) -> tensor<128x128xbf16> attributes {noinline = false} { + %ptr_0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc654) + %ptr_1 = tt.splat %stride_m : i32 -> tensor<128x1xi32> loc(#loc655) + %ptr_2 = arith.muli %ptr_0, %ptr_1 : tensor<128x1xi32> loc(#loc655) + %ptr_3 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc656) + %ptr_4 = tt.addptr %ptr_3, %ptr_2 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc656) + %ptr_5 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc657) + %ptr_6 = tt.splat %stride_n : i32 -> tensor<1x128xi32> loc(#loc658) + %ptr_7 = arith.muli %ptr_5, %ptr_6 : tensor<1x128xi32> loc(#loc658) + %ptr_8 = tt.broadcast %ptr_4 : tensor<128x1x!tt.ptr> -> 
tensor<128x128x!tt.ptr> loc(#loc659) + %ptr_9 = tt.broadcast %ptr_7 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc659) + %ptr_10 = tt.addptr %ptr_8, %ptr_9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc659) + %0 = tt.load %ptr_10 : tensor<128x128x!tt.ptr> loc(#loc218) + tt.return %0 : tensor<128x128xbf16> loc(#loc219) + ^bb1: // no predecessors + %1 = ub.poison : tensor<128x128xbf16> loc(#loc220) + tt.return %1 : tensor<128x128xbf16> loc(#loc220) + } loc(#loc211) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(35,)cconstexpr_bf16__(36,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc221)), %arg_K: !tt.ptr loc("arg_K"(#loc221)), %arg_V: !tt.ptr loc("arg_V"(#loc221)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc221)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc221)), %arg_DO: !tt.ptr loc("arg_DO"(#loc221)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc221)), %arg_DV: !tt.ptr loc("arg_DV"(#loc221)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc221)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc221)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc221)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc221)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc221)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc221)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc221)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc221)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc221)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc221)), %K: !tt.ptr loc("K"(#loc221)), %V: !tt.ptr loc("V"(#loc221)), %dq: tensor<128x128xf32> loc("dq"(#loc221)), %q: tensor<128x128xbf16> loc("q"(#loc221)), %do: tensor<128x128xbf16> loc("do"(#loc221)), %Di: tensor<128xf32> 
loc("Di"(#loc221)), %lse: tensor<128x1xf32> loc("lse"(#loc221)), %off_z: i32 loc("off_z"(#loc221)), %off_hq: i32 loc("off_hq"(#loc221)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc221)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc221)), %stride_kn: i32 loc("stride_kn"(#loc221)), %stride_kd: i32 loc("stride_kd"(#loc221)), %stride_vn: i32 loc("stride_vn"(#loc221)), %stride_vd: i32 loc("stride_vd"(#loc221)), %kv_indices: !tt.ptr loc("kv_indices"(#loc221)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc221))) -> tensor<128x128xf32> attributes {noinline = false} { + %Q_LEN = arith.constant 2048 : i32 loc(#loc695) + %KV_LEN = arith.constant 2048 : i32 loc(#loc696) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc697) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc698) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc699) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc700) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc700) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc701) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc701) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc702) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc703) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc703) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc704) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc704) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc704) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc705) + %vT_ptrs_10 = tt.splat 
%stride_vn : i32 -> tensor<1x64xi32> loc(#loc706) + %vT_ptrs_11 = arith.muli %vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc706) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc707) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc707) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc708) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc709) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc709) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc710) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc710) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc710) + %hi = arith.constant 2 : i32 loc(#loc711) + %hi_20 = arith.constant 2 : i32 loc(#loc711) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc711) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%KV_LEN) : (i32) -> i32 loc(#loc712) + %hi_23 = arith.constant 1 : i32 loc(#loc713) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc713) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc714) + %c0_i32 = arith.constant 0 : i32 loc(#loc242) + %c1_i32 = arith.constant 1 : i32 loc(#loc242) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc242) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc242) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc242) + %3 = ub.poison : i32 loc(#loc242) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(39,)cconstexpr_bf16__(40,)cconstexpr_1_d_44269504__(41,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %Q_LEN, %KV_LEN, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc716) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc717) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc718) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc719) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc719) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : 
i32 loc(#loc720) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc721) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc721) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc722) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc722) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc250) + } loc(#loc1027) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc251) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc252) + tt.return %4 : tensor<128x128xf32> loc(#loc252) + } loc(#loc221) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%x: i32 loc("x"(#loc201))) -> i32 attributes {noinline = false} { + %c64_i32 = arith.constant 64 : i32 loc(#loc202) + %c64_i32_0 = arith.constant 64 : i32 loc(#loc202) + %0 = arith.addi %x, %c64_i32_0 : i32 loc(#loc202) + %c1_i32 = arith.constant 1 : i32 loc(#loc203) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc203) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc203) + %c64_i32_2 = arith.constant 64 : i32 loc(#loc204) + %c64_i32_3 = arith.constant 64 : i32 loc(#loc204) + %2 = arith.divsi %1, %c64_i32_3 : i32 loc(#loc204) + tt.return %2 : i32 loc(#loc205) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc206) + tt.return %3 : i32 loc(#loc206) + } loc(#loc201) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(39,)cconstexpr_bf16__(40,)cconstexpr_1_d_44269504__(41,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc253)), 
%arg_K: !tt.ptr loc("arg_K"(#loc253)), %arg_V: !tt.ptr loc("arg_V"(#loc253)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc253)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc253)), %arg_DO: !tt.ptr loc("arg_DO"(#loc253)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc253)), %arg_DV: !tt.ptr loc("arg_DV"(#loc253)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc253)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc253)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc253)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc253)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc253)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc253)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc253)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc253)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc253)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc253)), %dq: tensor<128x128xf32> loc("dq"(#loc253)), %q: tensor<128x128xbf16> loc("q"(#loc253)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc253)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc253)), %do: tensor<128x128xbf16> loc("do"(#loc253)), %Di: tensor<128xf32> loc("Di"(#loc253)), %lse: tensor<128x1xf32> loc("lse"(#loc253)), %Q_LEN: i32 loc("Q_LEN"(#loc253)), %KV_LEN: i32 loc("KV_LEN"(#loc253)), %off_z: i32 loc("off_z"(#loc253)), %off_hq: i32 loc("off_hq"(#loc253)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc253)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc253)), %offs_k: tensor<128xi32> loc("offs_k"(#loc253)), %offs_v: tensor<128xi32> loc("offs_v"(#loc253)), %stride_kn: i32 loc("stride_kn"(#loc253)), %stride_kd: i32 loc("stride_kd"(#loc253)), %stride_vn: i32 loc("stride_vn"(#loc253)), %stride_vd: i32 loc("stride_vd"(#loc253)), %kv_indices: !tt.ptr loc("kv_indices"(#loc253)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc253))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc762) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc763) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc763) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc763) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc764) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc764) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc764) + %qk_5 = arith.mulf %qk_1, %qk_4 : tensor<128x64xf32> loc(#loc764) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc765) + %n_6 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S1_64S__(1,)cconstexpr_None_"(%n) : (tensor<1x64xi32>) -> tensor<1x64xi32> loc(#loc766) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc767) + %m_7 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S128_1S__(1,)cconstexpr_None_"(%m) : (tensor<128x1xi32>) -> tensor<128x1xi32> loc(#loc768) + %tmp1 = arith.constant false loc(#loc769) + %tmp1_8 = arith.constant dense : tensor<1xi1> loc(#loc769) + %tmp4 = tt.broadcast %m_7 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc770) + %tmp4_9 = tt.broadcast %n_6 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc770) + %tmp4_10 = arith.cmpi sge, %tmp4, %tmp4_9 : tensor<128x64xi32> loc(#loc770) + %tmp5 = arith.extsi %n_6 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc771) + 
%tmp7 = tt.addptr %in_ptr16, %off_z : !tt.ptr, i32 loc(#loc772) + %tmp7_11 = tt.load %tmp7 : !tt.ptr loc(#loc773) + %tmp8 = tt.splat %tmp7_11 : i64 -> tensor<1x64xi64> loc(#loc774) + %tmp8_12 = arith.cmpi slt, %tmp5, %tmp8 : tensor<1x64xi64> loc(#loc774) + %tmp9 = arith.extsi %m_7 : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc775) + %tmp10 = tt.splat %tmp7_11 : i64 -> tensor<128x1xi64> loc(#loc776) + %tmp10_13 = arith.cmpi slt, %tmp9, %tmp10 : tensor<128x1xi64> loc(#loc776) + %tmp11 = tt.broadcast %tmp8_12 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc777) + %tmp11_14 = tt.broadcast %tmp10_13 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc777) + %tmp11_15 = arith.andi %tmp11, %tmp11_14 : tensor<128x64xi1> loc(#loc777) + %tmp12 = arith.andi %tmp4_10, %tmp11_15 : tensor<128x64xi1> loc(#loc778) + %tmp13 = tt.expand_dims %tmp1_8 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc779) + %tmp13_16 = tt.broadcast %tmp13 : tensor<1x1xi1> -> tensor<128x64xi1> loc(#loc779) + %tmp13_17 = arith.ori %tmp13_16, %tmp12 : tensor<128x64xi1> loc(#loc779) + %tmp14 = arith.constant 2048 : i32 loc(#loc780) + %tmp14_18 = arith.constant dense<2048> : tensor<1xi32> loc(#loc780) + %tmp15 = tt.expand_dims %tmp14_18 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc781) + %tmp15_19 = tt.broadcast %tmp15 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc781) + %tmp15_20 = arith.cmpi sge, %n_6, %tmp15_19 : tensor<1x64xi32> loc(#loc781) + %tmp16 = tt.expand_dims %tmp14_18 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc782) + %tmp16_21 = tt.broadcast %tmp16 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc782) + %tmp16_22 = arith.remsi %n_6, %tmp16_21 : tensor<1x64xi32> loc(#loc782) + %tmp17 = arith.constant 0 : i32 loc(#loc783) + %tmp17_23 = arith.constant dense<0> : tensor<1xi32> loc(#loc783) + %tmp18 = tt.expand_dims %tmp17_23 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc784) + %tmp18_24 = tt.broadcast %tmp18 : tensor<1x1xi32> -> tensor<1x64xi32> 
loc(#loc784) + %tmp18_25 = arith.cmpi ne, %tmp16_22, %tmp18_24 : tensor<1x64xi32> loc(#loc784) + %tmp19 = arith.constant 0 : i32 loc(#loc785) + %tmp19_26 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc785) + %tmp19_27 = arith.cmpi slt, %tmp16_22, %tmp19_26 : tensor<1x64xi32> loc(#loc785) + %tmp20 = arith.constant 0 : i32 loc(#loc786) + %tmp20_28 = arith.constant dense<0> : tensor<1xi32> loc(#loc786) + %tmp20_29 = arith.cmpi slt, %tmp14_18, %tmp20_28 : tensor<1xi32> loc(#loc786) + %tmp21 = tt.expand_dims %tmp20_29 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc787) + %tmp21_30 = tt.broadcast %tmp21 : tensor<1x1xi1> -> tensor<1x64xi1> loc(#loc787) + %tmp21_31 = arith.cmpi ne, %tmp19_27, %tmp21_30 : tensor<1x64xi1> loc(#loc787) + %tmp22 = arith.andi %tmp18_25, %tmp21_31 : tensor<1x64xi1> loc(#loc788) + %tmp23 = tt.expand_dims %tmp14_18 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc789) + %tmp23_32 = tt.broadcast %tmp23 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc789) + %tmp23_33 = arith.addi %tmp16_22, %tmp23_32 : tensor<1x64xi32> loc(#loc789) + %tmp24 = arith.select %tmp22, %tmp23_33, %tmp16_22 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc790) + %tmp25 = arith.extsi %tmp24 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc791) + %tmp26 = tt.splat %tmp7_11 : i64 -> tensor<1x64xi64> loc(#loc792) + %tmp26_34 = arith.cmpi slt, %tmp25, %tmp26 : tensor<1x64xi64> loc(#loc792) + %tmp27 = arith.andi %tmp15_20, %tmp26_34 : tensor<1x64xi1> loc(#loc793) + %tmp28 = tt.broadcast %n_6 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc794) + %tmp28_35 = tt.broadcast %m_7 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc794) + %tmp28_36 = arith.subi %tmp28, %tmp28_35 : tensor<128x64xi32> loc(#loc794) + %tmp29 = tt.expand_dims %tmp14_18 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc795) + %tmp29_37 = tt.broadcast %tmp29 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc795) + %tmp29_38 = arith.remsi %tmp28_36, %tmp29_37 : tensor<128x64xi32> 
loc(#loc795) + %tmp30 = tt.expand_dims %tmp17_23 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc796) + %tmp30_39 = tt.broadcast %tmp30 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc796) + %tmp30_40 = arith.cmpi ne, %tmp29_38, %tmp30_39 : tensor<128x64xi32> loc(#loc796) + %tmp31 = arith.constant 0 : i32 loc(#loc797) + %tmp31_41 = arith.constant dense<0> : tensor<128x64xi32> loc(#loc797) + %tmp31_42 = arith.cmpi slt, %tmp29_38, %tmp31_41 : tensor<128x64xi32> loc(#loc797) + %tmp32 = tt.expand_dims %tmp20_29 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc798) + %tmp32_43 = tt.broadcast %tmp32 : tensor<1x1xi1> -> tensor<128x64xi1> loc(#loc798) + %tmp32_44 = arith.cmpi ne, %tmp31_42, %tmp32_43 : tensor<128x64xi1> loc(#loc798) + %tmp33 = arith.andi %tmp30_40, %tmp32_44 : tensor<128x64xi1> loc(#loc799) + %tmp34 = tt.expand_dims %tmp14_18 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc800) + %tmp34_45 = tt.broadcast %tmp34 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc800) + %tmp34_46 = arith.addi %tmp29_38, %tmp34_45 : tensor<128x64xi32> loc(#loc800) + %tmp35 = arith.select %tmp33, %tmp34_46, %tmp29_38 : tensor<128x64xi1>, tensor<128x64xi32> loc(#loc801) + %tmp36 = tt.expand_dims %tmp17_23 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc802) + %tmp36_47 = tt.broadcast %tmp36 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc802) + %tmp36_48 = arith.cmpi eq, %tmp35, %tmp36_47 : tensor<128x64xi32> loc(#loc802) + %tmp37 = tt.broadcast %tmp27 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc803) + %tmp37_49 = arith.andi %tmp37, %tmp36_48 : tensor<128x64xi1> loc(#loc803) + %tmp38 = arith.ori %tmp13_17, %tmp37_49 : tensor<128x64xi1> loc(#loc804) + %post_mod_scores = arith.constant 0xFF800000 : f32 loc(#loc805) + %post_mod_scores_50 = arith.constant 0xFF800000 : f32 loc(#loc805) + %post_mod_scores_51 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc805) + %post_mod_scores_52 = arith.select %tmp38, %qk_5, 
%post_mod_scores_51 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc805) + %post_mod_scores_53 = arith.constant 1.44269502 : f32 loc(#loc806) + %post_mod_scores_54 = arith.constant 1.44269502 : f32 loc(#loc806) + %post_mod_scores_55 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc806) + %post_mod_scores_56 = arith.mulf %post_mod_scores_52, %post_mod_scores_55 : tensor<128x64xf32> loc(#loc806) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc807) + %p_57 = arith.subf %post_mod_scores_56, %p : tensor<128x64xf32> loc(#loc807) + %p_58 = math.exp2 %p_57 : tensor<128x64xf32> loc(#loc808) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc809) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc810) + %dp_59 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc810) + %dp_60 = tt.dot %do, %vT, %dp_59, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc810) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc811) + %ds_61 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc812) + %ds_62 = arith.subf %dp_60, %ds_61 : tensor<128x64xf32> loc(#loc812) + %ds_63 = arith.mulf %p_58, %ds_62 : tensor<128x64xf32> loc(#loc813) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc814) + %scatter_mask_64 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc815) + %scatter_mask_65 = arith.cmpi slt, %scatter_mask, %scatter_mask_64 : tensor<128x1xi32> loc(#loc815) + %scatter_mask_66 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> 
-> tensor<1x64xi32> loc(#loc816) + %scatter_mask_67 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc817) + %scatter_mask_68 = arith.cmpi slt, %scatter_mask_66, %scatter_mask_67 : tensor<1x64xi32> loc(#loc817) + %scatter_mask_69 = tt.broadcast %scatter_mask_65 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc818) + %scatter_mask_70 = tt.broadcast %scatter_mask_68 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc818) + %scatter_mask_71 = arith.andi %scatter_mask_69, %scatter_mask_70 : tensor<128x64xi1> loc(#loc818) + %ds_72 = arith.constant 0.000000e+00 : f32 loc(#loc819) + %ds_73 = arith.constant 0.000000e+00 : f32 loc(#loc819) + %ds_74 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc819) + %ds_75 = arith.select %tmp38, %ds_63, %ds_74 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc819) + %ds_76 = arith.truncf %ds_75 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc820) + %dq_77 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc821) + %dq_78 = arith.constant 0.000000e+00 : f32 loc(#loc822) + %dq_79 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc822) + %dq_80 = tt.dot %ds_76, %dq_77, %dq_79, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc822) + %dq_81 = arith.addf %dq, %dq_80 : tensor<128x128xf32> loc(#loc823) + tt.return %dq_81 : tensor<128x128xf32> loc(#loc316) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc317) + tt.return %0 : tensor<128x128xf32> loc(#loc317) + } loc(#loc253) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%ptr: tensor<128x64x!tt.ptr> loc("ptr"(#loc211)), %offs_m: tensor<128xi32> loc("offs_m"(#loc211)), %offs_n: tensor<64xi32> loc("offs_n"(#loc211)), %N_LEN: i32 
loc("N_LEN"(#loc211))) -> tensor<128x64xbf16> attributes {noinline = false} { + %0 = tt.load %ptr : tensor<128x64x!tt.ptr> loc(#loc218) + tt.return %0 : tensor<128x64xbf16> loc(#loc219) + ^bb1: // no predecessors + %1 = ub.poison : tensor<128x64xbf16> loc(#loc220) + tt.return %1 : tensor<128x64xbf16> loc(#loc220) + } loc(#loc211) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S1_64S__(1,)cconstexpr_None_"(%indices: tensor<1x64xi32> loc("indices"(#loc318))) -> tensor<1x64xi32> attributes {noinline = false} { + tt.return %indices : tensor<1x64xi32> loc(#loc319) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1x64xi32> loc(#loc320) + tt.return %0 : tensor<1x64xi32> loc(#loc320) + } loc(#loc318) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S128_1S__(1,)cconstexpr_None_"(%indices: tensor<128x1xi32> loc("indices"(#loc318))) -> tensor<128x1xi32> attributes {noinline = false} { + tt.return %indices : tensor<128x1xi32> loc(#loc319) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x1xi32> loc(#loc320) + tt.return %0 : tensor<128x1xi32> loc(#loc320) + } loc(#loc318) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%loop_iter: i32 loc("loop_iter"(#loc321)), %col_indices: !tt.ptr loc("col_indices"(#loc321)), %total_blocks: i32 loc("total_blocks"(#loc321))) -> i32 attributes {noinline = false} { + %cur_block_idx = arith.constant 2 : i32 loc(#loc829) + %cur_block_idx_0 = arith.constant 2 : i32 loc(#loc829) + %cur_block_idx_1 = arith.divsi %loop_iter, %cur_block_idx_0 : i32 loc(#loc829) + %cur_block = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc830) + %cur_block_2 = tt.load %cur_block 
evictionPolicy = evict_last : !tt.ptr loc(#loc831) + %next_block = arith.constant 1 : i32 loc(#loc832) + %next_block_3 = arith.constant 1 : i32 loc(#loc832) + %next_block_4 = arith.addi %cur_block_idx_1, %next_block_3 : i32 loc(#loc832) + %next_block_5 = arith.cmpi slt, %next_block_4, %total_blocks : i32 loc(#loc833) + %next_block_6 = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc834) + %next_block_7 = arith.constant 1 : i32 loc(#loc835) + %next_block_8 = tt.addptr %next_block_6, %next_block_7 : !tt.ptr, i32 loc(#loc835) + %next_block_9 = tt.load %next_block_8, %next_block_5 evictionPolicy = evict_last : !tt.ptr loc(#loc836) + %needs_jump = arith.constant 1 : i32 loc(#loc837) + %needs_jump_10 = arith.constant 1 : i32 loc(#loc837) + %needs_jump_11 = arith.addi %loop_iter, %needs_jump_10 : i32 loc(#loc837) + %needs_jump_12 = arith.constant 2 : i32 loc(#loc838) + %needs_jump_13 = arith.constant 2 : i32 loc(#loc838) + %needs_jump_14 = arith.remsi %needs_jump_11, %needs_jump_13 : i32 loc(#loc838) + %needs_jump_15 = arith.constant 0 : i32 loc(#loc839) + %needs_jump_16 = arith.cmpi eq, %needs_jump_14, %needs_jump_15 : i32 loc(#loc839) + %jump_to_block = arith.subi %next_block_9, %cur_block_2 : i32 loc(#loc840) + %jump_to_block_17 = arith.constant 128 : i32 loc(#loc841) + %jump_to_block_18 = arith.constant 128 : i32 loc(#loc841) + %jump_to_block_19 = arith.muli %jump_to_block, %jump_to_block_18 : i32 loc(#loc841) + %jump_to_block_20 = arith.constant 64 : i32 loc(#loc842) + %jump_to_block_21 = arith.constant 64 : i32 loc(#loc842) + %jump_to_block_22 = arith.subi %jump_to_block_19, %jump_to_block_21 : i32 loc(#loc842) + %offset = arith.extui %needs_jump_16 : i1 to i32 loc(#loc843) + %offset_23 = arith.muli %jump_to_block_22, %offset : i32 loc(#loc843) + %offset_24 = arith.constant 1 : i32 loc(#loc844) + %offset_25 = arith.constant 1 : i32 loc(#loc844) + %offset_26 = arith.extui %needs_jump_16 : i1 to i32 loc(#loc844) + %offset_27 = arith.subi %offset_25, 
%offset_26 : i32 loc(#loc844) + %offset_28 = arith.constant 64 : i32 loc(#loc845) + %offset_29 = arith.constant 64 : i32 loc(#loc845) + %offset_30 = arith.muli %offset_27, %offset_29 : i32 loc(#loc845) + %offset_31 = arith.addi %offset_23, %offset_30 : i32 loc(#loc846) + tt.return %offset_31 : i32 loc(#loc340) + ^bb1: // no predecessors + %0 = ub.poison : i32 loc(#loc341) + tt.return %0 : i32 loc(#loc341) + } loc(#loc321) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(35,)cconstexpr_bf16__(36,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc221)), %arg_K: !tt.ptr loc("arg_K"(#loc221)), %arg_V: !tt.ptr loc("arg_V"(#loc221)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc221)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc221)), %arg_DO: !tt.ptr loc("arg_DO"(#loc221)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc221)), %arg_DV: !tt.ptr loc("arg_DV"(#loc221)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc221)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc221)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc221)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc221)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc221)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc221)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc221)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc221)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc221)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc221)), %K: !tt.ptr loc("K"(#loc221)), %V: !tt.ptr loc("V"(#loc221)), %dq: tensor<128x128xf32> loc("dq"(#loc221)), %q: tensor<128x128xbf16> loc("q"(#loc221)), %do: tensor<128x128xbf16> loc("do"(#loc221)), %Di: tensor<128xf32> loc("Di"(#loc221)), %lse: tensor<128x1xf32> loc("lse"(#loc221)), %off_z: 
i32 loc("off_z"(#loc221)), %off_hq: i32 loc("off_hq"(#loc221)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc221)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc221)), %stride_kn: i32 loc("stride_kn"(#loc221)), %stride_kd: i32 loc("stride_kd"(#loc221)), %stride_vn: i32 loc("stride_vn"(#loc221)), %stride_vd: i32 loc("stride_vd"(#loc221)), %kv_indices: !tt.ptr loc("kv_indices"(#loc221)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc221))) -> tensor<128x128xf32> attributes {noinline = false} { + %Q_LEN = arith.constant 2048 : i32 loc(#loc695) + %KV_LEN = arith.constant 2048 : i32 loc(#loc696) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc697) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc698) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc699) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc700) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc700) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc701) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc701) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc702) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc703) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc703) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc704) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc704) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc704) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc705) + %vT_ptrs_10 = tt.splat %stride_vn : i32 -> tensor<1x64xi32> loc(#loc706) + %vT_ptrs_11 = arith.muli 
%vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc706) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc707) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc707) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc708) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc709) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc709) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc710) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc710) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc710) + %hi = arith.constant 2 : i32 loc(#loc711) + %hi_20 = arith.constant 2 : i32 loc(#loc711) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc711) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%KV_LEN) : (i32) -> i32 loc(#loc712) + %hi_23 = arith.constant 1 : i32 loc(#loc713) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc713) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc714) + %c0_i32 = arith.constant 0 : i32 loc(#loc242) + %c1_i32 = arith.constant 1 : i32 loc(#loc242) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc242) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc242) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc242) + %3 = ub.poison : i32 loc(#loc242) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(39,)cconstexpr_bf16__(40,)cconstexpr_1_d_44269504__(41,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %Q_LEN, %KV_LEN, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc716) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc717) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc718) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc719) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc719) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : i32 
loc(#loc720) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc721) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc721) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc722) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc722) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc250) + } loc(#loc1027) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc251) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc252) + tt.return %4 : tensor<128x128xf32> loc(#loc252) + } loc(#loc221) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(39,)cconstexpr_bf16__(40,)cconstexpr_1_d_44269504__(41,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc253)), %arg_K: !tt.ptr loc("arg_K"(#loc253)), %arg_V: !tt.ptr loc("arg_V"(#loc253)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc253)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc253)), %arg_DO: !tt.ptr loc("arg_DO"(#loc253)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc253)), %arg_DV: !tt.ptr loc("arg_DV"(#loc253)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc253)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc253)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc253)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc253)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc253)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc253)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc253)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc253)), 
%in_ptr16: !tt.ptr loc("in_ptr16"(#loc253)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc253)), %dq: tensor<128x128xf32> loc("dq"(#loc253)), %q: tensor<128x128xbf16> loc("q"(#loc253)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc253)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc253)), %do: tensor<128x128xbf16> loc("do"(#loc253)), %Di: tensor<128xf32> loc("Di"(#loc253)), %lse: tensor<128x1xf32> loc("lse"(#loc253)), %Q_LEN: i32 loc("Q_LEN"(#loc253)), %KV_LEN: i32 loc("KV_LEN"(#loc253)), %off_z: i32 loc("off_z"(#loc253)), %off_hq: i32 loc("off_hq"(#loc253)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc253)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc253)), %offs_k: tensor<128xi32> loc("offs_k"(#loc253)), %offs_v: tensor<128xi32> loc("offs_v"(#loc253)), %stride_kn: i32 loc("stride_kn"(#loc253)), %stride_kd: i32 loc("stride_kd"(#loc253)), %stride_vn: i32 loc("stride_vn"(#loc253)), %stride_vd: i32 loc("stride_vd"(#loc253)), %kv_indices: !tt.ptr loc("kv_indices"(#loc253)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc253))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc762) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc763) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc763) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc763) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc764) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc764) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc764) + %qk_5 = arith.mulf 
%qk_1, %qk_4 : tensor<128x64xf32> loc(#loc764) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc765) + %n_6 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S1_64S__(1,)cconstexpr_None_"(%n) : (tensor<1x64xi32>) -> tensor<1x64xi32> loc(#loc766) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc767) + %m_7 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S128_1S__(1,)cconstexpr_None_"(%m) : (tensor<128x1xi32>) -> tensor<128x1xi32> loc(#loc768) + %post_mod_scores = arith.constant 1.44269502 : f32 loc(#loc806) + %post_mod_scores_8 = arith.constant 1.44269502 : f32 loc(#loc806) + %post_mod_scores_9 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc806) + %post_mod_scores_10 = arith.mulf %qk_5, %post_mod_scores_9 : tensor<128x64xf32> loc(#loc806) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc807) + %p_11 = arith.subf %post_mod_scores_10, %p : tensor<128x64xf32> loc(#loc807) + %p_12 = math.exp2 %p_11 : tensor<128x64xf32> loc(#loc808) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc809) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc810) + %dp_13 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc810) + %dp_14 = tt.dot %do, %vT, %dp_13, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc810) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> 
loc(#loc811) + %ds_15 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc812) + %ds_16 = arith.subf %dp_14, %ds_15 : tensor<128x64xf32> loc(#loc812) + %ds_17 = arith.mulf %p_12, %ds_16 : tensor<128x64xf32> loc(#loc813) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc814) + %scatter_mask_18 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc815) + %scatter_mask_19 = arith.cmpi slt, %scatter_mask, %scatter_mask_18 : tensor<128x1xi32> loc(#loc815) + %scatter_mask_20 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc816) + %scatter_mask_21 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc817) + %scatter_mask_22 = arith.cmpi slt, %scatter_mask_20, %scatter_mask_21 : tensor<1x64xi32> loc(#loc817) + %scatter_mask_23 = tt.broadcast %scatter_mask_19 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc818) + %scatter_mask_24 = tt.broadcast %scatter_mask_22 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc818) + %scatter_mask_25 = arith.andi %scatter_mask_23, %scatter_mask_24 : tensor<128x64xi1> loc(#loc818) + %ds_26 = arith.truncf %ds_17 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc820) + %dq_27 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc821) + %dq_28 = arith.constant 0.000000e+00 : f32 loc(#loc822) + %dq_29 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc822) + %dq_30 = tt.dot %ds_26, %dq_27, %dq_29, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc822) + %dq_31 = arith.addf %dq, %dq_30 : tensor<128x128xf32> loc(#loc823) + tt.return %dq_31 : tensor<128x128xf32> loc(#loc316) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc317) + tt.return %0 : tensor<128x128xf32> loc(#loc317) + } loc(#loc253) + tt.func private 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc342)), %arg_K: !tt.ptr loc("arg_K"(#loc342)), %arg_V: !tt.ptr loc("arg_V"(#loc342)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc342)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc342)), %arg_DO: !tt.ptr loc("arg_DO"(#loc342)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc342)), %arg_DV: !tt.ptr loc("arg_DV"(#loc342)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc342)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc342)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc342)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc342)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc342)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc342)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc342)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc342)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc342)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc342)), %Q: !tt.ptr loc("Q"(#loc342)), %DO: !tt.ptr loc("DO"(#loc342)), %DELTA: !tt.ptr loc("DELTA"(#loc342)), %LSE: !tt.ptr loc("LSE"(#loc342)), %dk: tensor<128x128xf32> loc("dk"(#loc342)), %dv: tensor<128x128xf32> loc("dv"(#loc342)), %k: tensor<128x128xbf16> loc("k"(#loc342)), %v: tensor<128x128xbf16> loc("v"(#loc342)), %off_z: i32 loc("off_z"(#loc342)), %off_hq: i32 loc("off_hq"(#loc342)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc342)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc342)), %stride_qm: i32 loc("stride_qm"(#loc342)), %stride_qd: i32 loc("stride_qd"(#loc342)), %stride_dom: i32 loc("stride_dom"(#loc342)), %stride_dod: i32 loc("stride_dod"(#loc342)), %q_indices: !tt.ptr loc("q_indices"(#loc342)), 
%sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc342))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %Q_LEN = arith.constant 2048 : i32 loc(#loc883) + %KV_LEN = arith.constant 2048 : i32 loc(#loc884) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc885) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc886) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc887) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc888) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc888) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc889) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc889) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc890) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc891) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc891) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc892) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc892) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc892) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc893) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc894) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc894) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc895) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc895) + %do_ptrs_14 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc896) + %do_ptrs_15 = tt.splat 
%stride_dod : i32 -> tensor<1x128xi32> loc(#loc897) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc897) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc898) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc898) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc898) + %hi = arith.constant 2 : i32 loc(#loc899) + %hi_20 = arith.constant 2 : i32 loc(#loc899) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc899) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%Q_LEN) : (i32) -> i32 loc(#loc900) + %hi_23 = arith.constant 1 : i32 loc(#loc901) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc901) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc902) + %c0_i32 = arith.constant 0 : i32 loc(#loc363) + %c1_i32 = arith.constant 1 : i32 loc(#loc363) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc363) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc363) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc363) + %3 = ub.poison : i32 loc(#loc363) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, 
%arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %Q_LEN, %KV_LEN, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc364) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc904) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc905) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc906) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc906) + %do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc907) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc908) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc908) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc909) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc909) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, 
tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc371) + } loc(#loc1029) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc372) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc373) + %5 = ub.poison : tensor<128x128xf32> loc(#loc373) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc373) + } loc(#loc342) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc374)), %arg_K: !tt.ptr loc("arg_K"(#loc374)), %arg_V: !tt.ptr loc("arg_V"(#loc374)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc374)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc374)), %arg_DO: !tt.ptr loc("arg_DO"(#loc374)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc374)), %arg_DV: !tt.ptr loc("arg_DV"(#loc374)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc374)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc374)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc374)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc374)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc374)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc374)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc374)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc374)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc374)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc374)), %dk: tensor<128x128xf32> loc("dk"(#loc374)), %dv: tensor<128x128xf32> loc("dv"(#loc374)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc374)), %k: tensor<128x128xbf16> loc("k"(#loc374)), %v: tensor<128x128xbf16> 
loc("v"(#loc374)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc374)), %DELTA: !tt.ptr loc("DELTA"(#loc374)), %LSE: !tt.ptr loc("LSE"(#loc374)), %Q_LEN: i32 loc("Q_LEN"(#loc374)), %KV_LEN: i32 loc("KV_LEN"(#loc374)), %off_z: i32 loc("off_z"(#loc374)), %off_hq: i32 loc("off_hq"(#loc374)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc374)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc374)), %offs_k: tensor<128xi32> loc("offs_k"(#loc374)), %offs_v: tensor<128xi32> loc("offs_v"(#loc374)), %stride_qm: i32 loc("stride_qm"(#loc374)), %stride_qd: i32 loc("stride_qd"(#loc374)), %stride_dom: i32 loc("stride_dom"(#loc374)), %stride_dod: i32 loc("stride_dod"(#loc374)), %q_indices: !tt.ptr loc("q_indices"(#loc374)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc374))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc950) + %lse = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc951) + %lse_0 = tt.addptr %lse, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc951) + %lse_1 = tt.load %lse_0 : tensor<64x!tt.ptr> loc(#loc952) + %lse_2 = arith.constant 0xFF800000 : f32 loc(#loc953) + %lse_3 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc953) + %lse_4 = arith.cmpf oeq, %lse_1, %lse_3 : tensor<64xf32> loc(#loc953) + %lse_5 = arith.constant 0.000000e+00 : f32 loc(#loc954) + %lse_6 = arith.constant 0.000000e+00 : f32 loc(#loc954) + %lse_7 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc954) + %lse_8 = arith.select %lse_4, %lse_7, %lse_1 : tensor<64xi1>, tensor<64xf32> loc(#loc954) + %qkT = arith.constant 0.000000e+00 
: f32 loc(#loc955) + %qkT_9 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc955) + %qkT_10 = tt.dot %k, %qT, %qkT_9, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc955) + %qkT_11 = arith.constant 0.0883883461 : f32 loc(#loc956) + %qkT_12 = arith.constant 0.0883883461 : f32 loc(#loc956) + %qkT_13 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc956) + %qkT_14 = arith.mulf %qkT_10, %qkT_13 : tensor<128x64xf32> loc(#loc956) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc957) + %m_15 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S1_64S__(1,)cconstexpr_None_"(%m) : (tensor<1x64xi32>) -> tensor<1x64xi32> loc(#loc958) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc959) + %n_16 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S128_1S__(1,)cconstexpr_None_"(%n) : (tensor<128x1xi32>) -> tensor<128x1xi32> loc(#loc960) + %tmp41 = arith.constant false loc(#loc961) + %tmp41_17 = arith.constant dense : tensor<1xi1> loc(#loc961) + %tmp44 = tt.broadcast %m_15 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc962) + %tmp44_18 = tt.broadcast %n_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc962) + %tmp44_19 = arith.cmpi sge, %tmp44, %tmp44_18 : tensor<128x64xi32> loc(#loc962) + %tmp45 = arith.extsi %n_16 : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc963) + %tmp47 = tt.addptr %in_ptr16, %off_z : !tt.ptr, i32 loc(#loc964) + %tmp47_20 = tt.load %tmp47 : !tt.ptr loc(#loc965) + %tmp48 = tt.splat %tmp47_20 : i64 -> tensor<128x1xi64> loc(#loc966) + %tmp48_21 = arith.cmpi slt, %tmp45, %tmp48 : tensor<128x1xi64> loc(#loc966) + %tmp49 = arith.extsi %m_15 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc967) + %tmp50 = tt.splat %tmp47_20 : i64 -> 
tensor<1x64xi64> loc(#loc968) + %tmp50_22 = arith.cmpi slt, %tmp49, %tmp50 : tensor<1x64xi64> loc(#loc968) + %tmp51 = tt.broadcast %tmp48_21 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc969) + %tmp51_23 = tt.broadcast %tmp50_22 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc969) + %tmp51_24 = arith.andi %tmp51, %tmp51_23 : tensor<128x64xi1> loc(#loc969) + %tmp52 = arith.andi %tmp44_19, %tmp51_24 : tensor<128x64xi1> loc(#loc970) + %tmp53 = tt.expand_dims %tmp41_17 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc971) + %tmp53_25 = tt.broadcast %tmp53 : tensor<1x1xi1> -> tensor<128x64xi1> loc(#loc971) + %tmp53_26 = arith.ori %tmp53_25, %tmp52 : tensor<128x64xi1> loc(#loc971) + %tmp54 = arith.constant 2048 : i32 loc(#loc972) + %tmp54_27 = arith.constant dense<2048> : tensor<1xi32> loc(#loc972) + %tmp55 = tt.expand_dims %tmp54_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc973) + %tmp55_28 = tt.broadcast %tmp55 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc973) + %tmp55_29 = arith.cmpi sge, %n_16, %tmp55_28 : tensor<128x1xi32> loc(#loc973) + %tmp56 = tt.expand_dims %tmp54_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc974) + %tmp56_30 = tt.broadcast %tmp56 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc974) + %tmp56_31 = arith.remsi %n_16, %tmp56_30 : tensor<128x1xi32> loc(#loc974) + %tmp57 = arith.constant 0 : i32 loc(#loc975) + %tmp57_32 = arith.constant dense<0> : tensor<1xi32> loc(#loc975) + %tmp58 = tt.expand_dims %tmp57_32 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc976) + %tmp58_33 = tt.broadcast %tmp58 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc976) + %tmp58_34 = arith.cmpi ne, %tmp56_31, %tmp58_33 : tensor<128x1xi32> loc(#loc976) + %tmp59 = arith.constant 0 : i32 loc(#loc977) + %tmp59_35 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc977) + %tmp59_36 = arith.cmpi slt, %tmp56_31, %tmp59_35 : tensor<128x1xi32> loc(#loc977) + %tmp60 = arith.constant 0 : i32 loc(#loc978) + %tmp60_37 = 
arith.constant dense<0> : tensor<1xi32> loc(#loc978) + %tmp60_38 = arith.cmpi slt, %tmp54_27, %tmp60_37 : tensor<1xi32> loc(#loc978) + %tmp61 = tt.expand_dims %tmp60_38 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc979) + %tmp61_39 = tt.broadcast %tmp61 : tensor<1x1xi1> -> tensor<128x1xi1> loc(#loc979) + %tmp61_40 = arith.cmpi ne, %tmp59_36, %tmp61_39 : tensor<128x1xi1> loc(#loc979) + %tmp62 = arith.andi %tmp58_34, %tmp61_40 : tensor<128x1xi1> loc(#loc980) + %tmp63 = tt.expand_dims %tmp54_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc981) + %tmp63_41 = tt.broadcast %tmp63 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc981) + %tmp63_42 = arith.addi %tmp56_31, %tmp63_41 : tensor<128x1xi32> loc(#loc981) + %tmp64 = arith.select %tmp62, %tmp63_42, %tmp56_31 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc982) + %tmp65 = arith.extsi %tmp64 : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc983) + %tmp66 = tt.splat %tmp47_20 : i64 -> tensor<128x1xi64> loc(#loc984) + %tmp66_43 = arith.cmpi slt, %tmp65, %tmp66 : tensor<128x1xi64> loc(#loc984) + %tmp67 = arith.andi %tmp55_29, %tmp66_43 : tensor<128x1xi1> loc(#loc985) + %tmp68 = tt.broadcast %n_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc986) + %tmp68_44 = tt.broadcast %m_15 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc986) + %tmp68_45 = arith.subi %tmp68, %tmp68_44 : tensor<128x64xi32> loc(#loc986) + %tmp69 = tt.expand_dims %tmp54_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc987) + %tmp69_46 = tt.broadcast %tmp69 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc987) + %tmp69_47 = arith.remsi %tmp68_45, %tmp69_46 : tensor<128x64xi32> loc(#loc987) + %tmp70 = tt.expand_dims %tmp57_32 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc988) + %tmp70_48 = tt.broadcast %tmp70 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc988) + %tmp70_49 = arith.cmpi ne, %tmp69_47, %tmp70_48 : tensor<128x64xi32> loc(#loc988) + %tmp71 = arith.constant 0 : i32 loc(#loc989) + %tmp71_50 
= arith.constant dense<0> : tensor<128x64xi32> loc(#loc989) + %tmp71_51 = arith.cmpi slt, %tmp69_47, %tmp71_50 : tensor<128x64xi32> loc(#loc989) + %tmp72 = tt.expand_dims %tmp60_38 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc990) + %tmp72_52 = tt.broadcast %tmp72 : tensor<1x1xi1> -> tensor<128x64xi1> loc(#loc990) + %tmp72_53 = arith.cmpi ne, %tmp71_51, %tmp72_52 : tensor<128x64xi1> loc(#loc990) + %tmp73 = arith.andi %tmp70_49, %tmp72_53 : tensor<128x64xi1> loc(#loc991) + %tmp74 = tt.expand_dims %tmp54_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc992) + %tmp74_54 = tt.broadcast %tmp74 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc992) + %tmp74_55 = arith.addi %tmp69_47, %tmp74_54 : tensor<128x64xi32> loc(#loc992) + %tmp75 = arith.select %tmp73, %tmp74_55, %tmp69_47 : tensor<128x64xi1>, tensor<128x64xi32> loc(#loc993) + %tmp76 = tt.expand_dims %tmp57_32 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc994) + %tmp76_56 = tt.broadcast %tmp76 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc994) + %tmp76_57 = arith.cmpi eq, %tmp75, %tmp76_56 : tensor<128x64xi32> loc(#loc994) + %tmp77 = tt.broadcast %tmp67 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc995) + %tmp77_58 = arith.andi %tmp77, %tmp76_57 : tensor<128x64xi1> loc(#loc995) + %tmp78 = arith.ori %tmp53_26, %tmp77_58 : tensor<128x64xi1> loc(#loc996) + %post_mod_scores = arith.constant 0xFF800000 : f32 loc(#loc997) + %post_mod_scores_59 = arith.constant 0xFF800000 : f32 loc(#loc997) + %post_mod_scores_60 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc997) + %post_mod_scores_61 = arith.select %tmp78, %qkT_14, %post_mod_scores_60 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc997) + %post_mod_scores_62 = arith.constant 1.44269502 : f32 loc(#loc998) + %post_mod_scores_63 = arith.constant 1.44269502 : f32 loc(#loc998) + %post_mod_scores_64 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc998) + %post_mod_scores_65 = arith.mulf 
%post_mod_scores_61, %post_mod_scores_64 : tensor<128x64xf32> loc(#loc998) + %pT = tt.expand_dims %lse_8 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc999) + %pT_66 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1000) + %pT_67 = arith.subf %post_mod_scores_65, %pT_66 : tensor<128x64xf32> loc(#loc1000) + %pT_68 = math.exp2 %pT_67 : tensor<128x64xf32> loc(#loc1001) + %do = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1002) + %dv_69 = arith.truncf %pT_68 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1003) + %dv_70 = arith.constant 0.000000e+00 : f32 loc(#loc1004) + %dv_71 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1004) + %dv_72 = tt.dot %dv_69, %do, %dv_71, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1004) + %dv_73 = arith.addf %dv, %dv_72 : tensor<128x128xf32> loc(#loc1005) + %Di = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1006) + %Di_74 = tt.addptr %Di, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1006) + %Di_75 = tt.load %Di_74 : tensor<64x!tt.ptr> loc(#loc1007) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1008) + %dpT_76 = arith.constant 0.000000e+00 : f32 loc(#loc1009) + %dpT_77 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1009) + %dpT_78 = tt.dot %v, %dpT, %dpT_77, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1009) + %dsT = tt.expand_dims %Di_75 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1010) + %dsT_79 = tt.broadcast %dsT : tensor<1x64xf32> -> 
tensor<128x64xf32> loc(#loc1011) + %dsT_80 = arith.subf %dpT_78, %dsT_79 : tensor<128x64xf32> loc(#loc1011) + %dsT_81 = arith.mulf %pT_68, %dsT_80 : tensor<128x64xf32> loc(#loc1012) + %dsT_82 = arith.constant 0.000000e+00 : f32 loc(#loc1013) + %dsT_83 = arith.constant 0.000000e+00 : f32 loc(#loc1013) + %dsT_84 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1013) + %dsT_85 = arith.select %tmp78, %dsT_81, %dsT_84 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1013) + %dk_86 = arith.truncf %dsT_85 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1014) + %dk_87 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1015) + %dk_88 = arith.constant 0.000000e+00 : f32 loc(#loc1016) + %dk_89 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1016) + %dk_90 = tt.dot %dk_86, %dk_87, %dk_89, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1016) + %dk_91 = arith.addf %dk, %dk_90 : tensor<128x128xf32> loc(#loc1017) + tt.return %dk_91, %dv_73 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc443) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc444) + %1 = ub.poison : tensor<128x128xf32> loc(#loc444) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc444) + } loc(#loc374) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: tensor<64x128x!tt.ptr> loc("ptr"(#loc211)), %offs_m: tensor<64xi32> loc("offs_m"(#loc211)), %offs_n: tensor<128xi32> loc("offs_n"(#loc211)), %M_LEN: i32 loc("M_LEN"(#loc211))) -> tensor<64x128xbf16> attributes {noinline = false} { + %0 = tt.load %ptr : tensor<64x128x!tt.ptr> loc(#loc218) + tt.return %0 : tensor<64x128xbf16> loc(#loc219) + ^bb1: // no predecessors + %1 = ub.poison : 
tensor<64x128xbf16> loc(#loc220) + tt.return %1 : tensor<64x128xbf16> loc(#loc220) + } loc(#loc211) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc342)), %arg_K: !tt.ptr loc("arg_K"(#loc342)), %arg_V: !tt.ptr loc("arg_V"(#loc342)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc342)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc342)), %arg_DO: !tt.ptr loc("arg_DO"(#loc342)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc342)), %arg_DV: !tt.ptr loc("arg_DV"(#loc342)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc342)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc342)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc342)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc342)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc342)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc342)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc342)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc342)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc342)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc342)), %Q: !tt.ptr loc("Q"(#loc342)), %DO: !tt.ptr loc("DO"(#loc342)), %DELTA: !tt.ptr loc("DELTA"(#loc342)), %LSE: !tt.ptr loc("LSE"(#loc342)), %dk: tensor<128x128xf32> loc("dk"(#loc342)), %dv: tensor<128x128xf32> loc("dv"(#loc342)), %k: tensor<128x128xbf16> loc("k"(#loc342)), %v: tensor<128x128xbf16> loc("v"(#loc342)), %off_z: i32 loc("off_z"(#loc342)), %off_hq: i32 loc("off_hq"(#loc342)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc342)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc342)), %stride_qm: i32 loc("stride_qm"(#loc342)), %stride_qd: i32 loc("stride_qd"(#loc342)), %stride_dom: i32 
loc("stride_dom"(#loc342)), %stride_dod: i32 loc("stride_dod"(#loc342)), %q_indices: !tt.ptr loc("q_indices"(#loc342)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc342))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %Q_LEN = arith.constant 2048 : i32 loc(#loc883) + %KV_LEN = arith.constant 2048 : i32 loc(#loc884) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc885) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc886) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc887) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc888) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc888) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc889) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc889) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc890) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc891) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc891) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc892) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc892) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc892) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc893) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc894) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc894) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc895) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc895) + %do_ptrs_14 
= tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc896) + %do_ptrs_15 = tt.splat %stride_dod : i32 -> tensor<1x128xi32> loc(#loc897) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc897) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc898) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc898) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc898) + %hi = arith.constant 2 : i32 loc(#loc899) + %hi_20 = arith.constant 2 : i32 loc(#loc899) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc899) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%Q_LEN) : (i32) -> i32 loc(#loc900) + %hi_23 = arith.constant 1 : i32 loc(#loc901) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc901) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc902) + %c0_i32 = arith.constant 0 : i32 loc(#loc363) + %c1_i32 = arith.constant 1 : i32 loc(#loc363) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc363) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc363) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc363) + %3 = ub.poison : i32 loc(#loc363) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call 
@"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %in_ptr16, %out_ptr0, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %Q_LEN, %KV_LEN, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc364) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc904) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc905) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc906) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc906) + 
%do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc907) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc908) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc908) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc909) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc909) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc371) + } loc(#loc1029) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc372) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc373) + %5 = ub.poison : tensor<128x128xf32> loc(#loc373) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc373) + } loc(#loc342) + tt.func private @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc374)), %arg_K: !tt.ptr loc("arg_K"(#loc374)), %arg_V: !tt.ptr loc("arg_V"(#loc374)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc374)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc374)), %arg_DO: !tt.ptr loc("arg_DO"(#loc374)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc374)), %arg_DV: !tt.ptr loc("arg_DV"(#loc374)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc374)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc374)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc374)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc374)), %arg_FULL_KV_NUM_BLKS: !tt.ptr 
loc("arg_FULL_KV_NUM_BLKS"(#loc374)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc374)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc374)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc374)), %in_ptr16: !tt.ptr loc("in_ptr16"(#loc374)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc374)), %dk: tensor<128x128xf32> loc("dk"(#loc374)), %dv: tensor<128x128xf32> loc("dv"(#loc374)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc374)), %k: tensor<128x128xbf16> loc("k"(#loc374)), %v: tensor<128x128xbf16> loc("v"(#loc374)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc374)), %DELTA: !tt.ptr loc("DELTA"(#loc374)), %LSE: !tt.ptr loc("LSE"(#loc374)), %Q_LEN: i32 loc("Q_LEN"(#loc374)), %KV_LEN: i32 loc("KV_LEN"(#loc374)), %off_z: i32 loc("off_z"(#loc374)), %off_hq: i32 loc("off_hq"(#loc374)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc374)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc374)), %offs_k: tensor<128xi32> loc("offs_k"(#loc374)), %offs_v: tensor<128xi32> loc("offs_v"(#loc374)), %stride_qm: i32 loc("stride_qm"(#loc374)), %stride_qd: i32 loc("stride_qd"(#loc374)), %stride_dom: i32 loc("stride_dom"(#loc374)), %stride_dod: i32 loc("stride_dod"(#loc374)), %q_indices: !tt.ptr loc("q_indices"(#loc374)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc374))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc950) + %lse = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc951) + %lse_0 = tt.addptr %lse, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc951) + %lse_1 = tt.load %lse_0 : tensor<64x!tt.ptr> 
loc(#loc952) + %lse_2 = arith.constant 0xFF800000 : f32 loc(#loc953) + %lse_3 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc953) + %lse_4 = arith.cmpf oeq, %lse_1, %lse_3 : tensor<64xf32> loc(#loc953) + %lse_5 = arith.constant 0.000000e+00 : f32 loc(#loc954) + %lse_6 = arith.constant 0.000000e+00 : f32 loc(#loc954) + %lse_7 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc954) + %lse_8 = arith.select %lse_4, %lse_7, %lse_1 : tensor<64xi1>, tensor<64xf32> loc(#loc954) + %qkT = arith.constant 0.000000e+00 : f32 loc(#loc955) + %qkT_9 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc955) + %qkT_10 = tt.dot %k, %qT, %qkT_9, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc955) + %qkT_11 = arith.constant 0.0883883461 : f32 loc(#loc956) + %qkT_12 = arith.constant 0.0883883461 : f32 loc(#loc956) + %qkT_13 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc956) + %qkT_14 = arith.mulf %qkT_10, %qkT_13 : tensor<128x64xf32> loc(#loc956) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc957) + %m_15 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S1_64S__(1,)cconstexpr_None_"(%m) : (tensor<1x64xi32>) -> tensor<1x64xi32> loc(#loc958) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc959) + %n_16 = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.get_bounded_indices__i32S128_1S__(1,)cconstexpr_None_"(%n) : (tensor<128x1xi32>) -> tensor<128x1xi32> loc(#loc960) + %post_mod_scores = arith.constant 1.44269502 : f32 loc(#loc998) + %post_mod_scores_17 = arith.constant 1.44269502 : f32 loc(#loc998) + %post_mod_scores_18 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc998) + %post_mod_scores_19 = arith.mulf %qkT_14, %post_mod_scores_18 : 
tensor<128x64xf32> loc(#loc998) + %pT = tt.expand_dims %lse_8 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc999) + %pT_20 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1000) + %pT_21 = arith.subf %post_mod_scores_19, %pT_20 : tensor<128x64xf32> loc(#loc1000) + %pT_22 = math.exp2 %pT_21 : tensor<128x64xf32> loc(#loc1001) + %do = tt.call @"torch._inductor.runtime.compile_tasks.cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1002) + %dv_23 = arith.truncf %pT_22 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1003) + %dv_24 = arith.constant 0.000000e+00 : f32 loc(#loc1004) + %dv_25 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1004) + %dv_26 = tt.dot %dv_23, %do, %dv_25, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1004) + %dv_27 = arith.addf %dv, %dv_26 : tensor<128x128xf32> loc(#loc1005) + %Di = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1006) + %Di_28 = tt.addptr %Di, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1006) + %Di_29 = tt.load %Di_28 : tensor<64x!tt.ptr> loc(#loc1007) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1008) + %dpT_30 = arith.constant 0.000000e+00 : f32 loc(#loc1009) + %dpT_31 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1009) + %dpT_32 = tt.dot %v, %dpT, %dpT_31, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1009) + %dsT = tt.expand_dims %Di_29 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1010) + %dsT_33 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1011) + %dsT_34 
= arith.subf %dpT_32, %dsT_33 : tensor<128x64xf32> loc(#loc1011) + %dsT_35 = arith.mulf %pT_22, %dsT_34 : tensor<128x64xf32> loc(#loc1012) + %dk_36 = arith.truncf %dsT_35 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1014) + %dk_37 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1015) + %dk_38 = arith.constant 0.000000e+00 : f32 loc(#loc1016) + %dk_39 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1016) + %dk_40 = tt.dot %dk_36, %dk_37, %dk_39, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1016) + %dk_41 = arith.addf %dk, %dk_40 : tensor<128x128xf32> loc(#loc1017) + tt.return %dk_41, %dv_27 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc443) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc444) + %1 = ub.poison : tensor<128x128xf32> loc(#loc444) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc444) + } loc(#loc374) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":94:49) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":95:49) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":96:49) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":97:53) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":99:53) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":100:53) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":102:9) +#loc8 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":103:9) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":104:10) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":105:12) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":106:10) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":107:13) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":111:24) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":112:36) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":113:34) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":115:27) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":116:28) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":117:23) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":119:15) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":120:16) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":122:28) +#loc22 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:25) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:47) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:35) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:59) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":125:25) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":125:47) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":125:35) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":125:59) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:27) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:50) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:37) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:61) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":131:9) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":132:9) +#loc36 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":133:10) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":135:14) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":136:26) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":137:26) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:14) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:7) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":140:24) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":142:29) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":143:30) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:29) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:54) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:44) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":145:35) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":146:41) +#loc50 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":147:31) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":148:26) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":149:26) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":151:35) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":152:42) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":152:54) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:55) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:78) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:50) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:83) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:68) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:30) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:52) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:40) +#loc64 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:63) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:32) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:55) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:42) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:66) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":160:32) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":160:55) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":160:42) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":160:66) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:30) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:35) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:46) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:56) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":163:17) +#loc78 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":164:19) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":167:19) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":168:21) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":169:25) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":172:22) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":174:36) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":175:42) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":175:29) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":178:107) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":179:111) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:34) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:25) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:33) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:26) +#loc92 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:30) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:50) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":191:18) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":195:30) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:27) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:41) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:53) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:39) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:42) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:29) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":207:12) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":214:39) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:31) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:45) +#loc106 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:62) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:43) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":218:46) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":218:33) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":226:16) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:32) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:43) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:24) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:63) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:74) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:56) +#loc117 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":232:14) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":234:30) +#loc119 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":239:29) +#loc120 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":240:30) +#loc121 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":242:26) +#loc122 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":244:30) +#loc123 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":245:25) +#loc124 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":246:25) +#loc125 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":249:22) +#loc126 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":250:22) +#loc127 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":252:25) +#loc128 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":253:42) +#loc129 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":253:29) +#loc130 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":256:107) +#loc131 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":257:107) +#loc132 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":262:30) +#loc133 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:32) +#loc134 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:51) +#loc135 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:34) +#loc136 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:56) +#loc137 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:44) +#loc138 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:67) +#loc139 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:36) +#loc140 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:59) +#loc141 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:46) +#loc142 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:70) +#loc143 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":268:36) +#loc144 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":268:59) +#loc145 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":268:46) +#loc146 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":268:70) +#loc147 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:34) +#loc148 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:39) +#loc149 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:50) +#loc150 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:60) +#loc151 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":271:21) +#loc152 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":272:23) +#loc153 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":275:25) +#loc154 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":276:29) +#loc155 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":278:39) +#loc156 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":279:46) +#loc157 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":279:58) +#loc158 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:58) +#loc159 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:80) +#loc160 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:53) +#loc161 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:81) +#loc162 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:70) +#loc163 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":286:32) +#loc164 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:30) +#loc165 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:43) +#loc166 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:55) +#loc167 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:42) +#loc168 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:45) +#loc169 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:32) +#loc170 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":298:16) +#loc171 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":306:41) +#loc172 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:34) +#loc173 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:47) +#loc174 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:64) +#loc175 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:46) +#loc176 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":310:49) +#loc177 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":310:36) +#loc178 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":318:20) +#loc179 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":303:12) +#loc180 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:31) +#loc181 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:42) +#loc182 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:23) +#loc183 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:62) +#loc184 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:73) +#loc185 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:55) +#loc186 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":325:26) +#loc187 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":326:25) +#loc188 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":327:25) +#loc189 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":330:30) +#loc190 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":334:14) +#loc191 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":337:29) +#loc192 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:31) +#loc193 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:27) +#loc194 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:48) +#loc195 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:41) +#loc196 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:66) +#loc197 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:58) +#loc198 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:29) +#loc199 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:69) +#loc200 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:4) +#loc202 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:16) +#loc203 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc204 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc205 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:11) +#loc206 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:4) +#loc207 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc208 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc209 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc210 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc212 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:27) +#loc213 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:38) +#loc214 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:20) +#loc215 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:56) +#loc216 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:67) +#loc217 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:49) +#loc218 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":835:23) +#loc219 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":835:15) +#loc220 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":828:4) +#loc222 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":384:12) +#loc223 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":385:13) +#loc224 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":387:26) +#loc225 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":388:26) +#loc226 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:26) +#loc227 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:37) +#loc228 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:18) +#loc229 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:56) +#loc230 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:67) +#loc231 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:49) +#loc232 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:26) +#loc233 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:37) +#loc234 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:18) +#loc235 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:56) +#loc236 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:67) +#loc237 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:49) +#loc238 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:43) +#loc239 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:90) +#loc240 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:101) +#loc241 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:63) +#loc242 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":397:28) +#loc243 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":405:12) +#loc244 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":411:64) +#loc245 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:28) +#loc246 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:19) +#loc247 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":415:28) +#loc248 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":415:19) +#loc249 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":417:19) +#loc250 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":417:8) +#loc251 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":419:11) +#loc252 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":419:4) +#loc254 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":458:105) +#loc255 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":459:19) +#loc256 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":461:14) +#loc257 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":464:36) +#loc258 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":464:46) +#loc259 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":467:36) +#loc260 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":467:46) +#loc261 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":479:35) +#loc262 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":482:23) +#loc263 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":483:23) +#loc264 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:34) +#loc265 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:23) +#loc266 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":486:22) +#loc267 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":487:23) +#loc268 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":488:23) +#loc269 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":489:23) +#loc270 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":490:23) +#loc271 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":491:23) +#loc272 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":492:35) +#loc273 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":493:24) +#loc274 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":494:24) +#loc275 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":495:32) +#loc276 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":496:25) +#loc277 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":497:92) +#loc278 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":498:92) +#loc279 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":499:25) +#loc280 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":500:24) +#loc281 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":501:24) +#loc282 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":502:39) +#loc283 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":503:25) +#loc284 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":504:24) +#loc285 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":505:24) +#loc286 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":506:23) +#loc287 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":507:25) +#loc288 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":508:25) +#loc289 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":509:92) +#loc290 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":510:25) +#loc291 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":511:24) +#loc292 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":512:24) +#loc293 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":513:39) +#loc294 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":514:25) +#loc295 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":515:24) +#loc296 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":516:24) +#loc297 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":521:69) +#loc298 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":524:27) +#loc299 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:39) +#loc300 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:21) +#loc301 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":528:104) +#loc302 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":530:20) +#loc303 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:22) +#loc304 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:19) +#loc305 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:14) +#loc306 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":542:32) +#loc307 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":542:43) +#loc308 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":542:62) +#loc309 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":542:73) +#loc310 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":542:54) +#loc311 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":549:43) +#loc312 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":551:15) +#loc313 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:30) +#loc314 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:21) +#loc315 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:10) +#loc316 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":555:11) +#loc317 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":555:4) +#loc319 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":798:11) +#loc320 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":798:4) +#loc322 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":788:33) +#loc323 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:38) +#loc324 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:24) +#loc325 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:109) +#loc326 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:113) +#loc327 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:39) +#loc328 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:55) +#loc329 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:25) +#loc330 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:30) +#loc331 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:35) +#loc332 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:60) +#loc333 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:34) +#loc334 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:48) +#loc335 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:63) +#loc336 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:29) +#loc337 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:47) +#loc338 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:61) +#loc339 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:42) +#loc340 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":794:11) +#loc341 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":794:4) +#loc343 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":595:12) +#loc344 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":596:13) +#loc345 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":598:26) +#loc346 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":599:26) +#loc347 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:26) +#loc348 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:37) +#loc349 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:18) +#loc350 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:56) +#loc351 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:67) +#loc352 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:49) +#loc353 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:27) +#loc354 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:38) +#loc355 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:19) +#loc356 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:58) +#loc357 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:69) +#loc358 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:51) +#loc359 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:42) +#loc360 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:87) +#loc361 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:98) +#loc362 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:61) +#loc363 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":610:28) +#loc364 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":618:12) +#loc365 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":623:62) +#loc366 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:28) +#loc367 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:19) +#loc368 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:28) +#loc369 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:19) +#loc370 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":628:19) +#loc371 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":628:8) +#loc372 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":630:11) +#loc373 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":630:4) +#loc375 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":669:105) +#loc376 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:28) +#loc377 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:22) +#loc378 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:26) +#loc379 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:46) +#loc380 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":676:20) +#loc381 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":678:15) +#loc382 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":680:36) +#loc383 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":680:46) +#loc384 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":683:36) +#loc385 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":683:46) +#loc386 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":695:36) +#loc387 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":698:25) +#loc388 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":699:25) +#loc389 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:35) +#loc390 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:24) +#loc391 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":702:24) +#loc392 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":703:25) +#loc393 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":704:24) +#loc394 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":705:24) +#loc395 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":706:24) +#loc396 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":707:24) +#loc397 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":708:35) +#loc398 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":709:25) +#loc399 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":710:25) +#loc400 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":711:32) +#loc401 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":712:25) +#loc402 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":713:92) +#loc403 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":714:92) +#loc404 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":715:25) +#loc405 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":716:24) +#loc406 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":717:24) +#loc407 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":718:39) +#loc408 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":719:25) +#loc409 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":720:24) +#loc410 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":721:24) +#loc411 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":722:24) +#loc412 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":723:25) +#loc413 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":724:25) +#loc414 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":725:92) +#loc415 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":726:25) +#loc416 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":727:24) +#loc417 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":728:24) +#loc418 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":729:39) +#loc419 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":730:25) +#loc420 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":731:24) +#loc421 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":732:24) +#loc422 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":736:69) +#loc423 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":739:27) +#loc424 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:44) +#loc425 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:40) +#loc426 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:22) +#loc427 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":741:99) +#loc428 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:24) +#loc429 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:43) +#loc430 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:10) +#loc431 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:29) +#loc432 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:21) +#loc433 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:29) +#loc434 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:20) +#loc435 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:25) +#loc436 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:22) +#loc437 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:16) +#loc438 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":773:45) +#loc439 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:24) +#loc440 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:52) +#loc441 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:43) +#loc442 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:10) +#loc443 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":777:11) +#loc444 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":777:4) +#loc463 = loc("ZQ"(#loc7)) +#loc464 = loc("HQ"(#loc8)) +#loc465 = loc("HKV"(#loc9)) +#loc466 = loc("Q_LEN"(#loc10)) +#loc467 = loc("ZKV"(#loc11)) +#loc468 = loc("KV_LEN"(#loc12)) +#loc469 = loc("pid"(#loc13)) +#loc470 = loc("NUM_KV_BLOCKS"(#loc14)) +#loc471 = loc("NUM_Q_BLOCKS"(#loc15)) +#loc472 = loc("off_zq"(#loc16)) +#loc473 = loc("off_hkv"(#loc17)) +#loc474 = loc("off_zkv"(#loc18)) +#loc475 = loc("SPARSE_Z"(#loc19)) +#loc476 = loc("SPARSE_HQ"(#loc20)) +#loc477 = loc("sparse_idx_z"(#loc21)) +#loc478 = loc("k_adj"(#loc22)) +#loc479 = loc("k_adj"(#loc23)) +#loc480 = loc("k_adj"(#loc24)) +#loc481 = loc("k_adj"(#loc25)) +#loc482 = loc("v_adj"(#loc26)) +#loc483 = loc("v_adj"(#loc27)) +#loc484 = loc("v_adj"(#loc28)) +#loc485 = loc("v_adj"(#loc29)) +#loc486 = loc("dv_adj"(#loc30)) +#loc487 = loc("dv_adj"(#loc31)) +#loc488 = loc("dv_adj"(#loc32)) +#loc489 = loc("dv_adj"(#loc33)) +#loc490 = loc("K"(#loc34)) +#loc491 = loc("V"(#loc35)) +#loc492 = loc("DV"(#loc36)) +#loc493 = loc("RCP_LN2"(#loc37)) +#loc494 = 
loc("offs_k"(#loc38)) +#loc495 = loc("offs_v"(#loc39)) +#loc496 = loc("off_pid"(#loc42)) +#loc497 = loc("SPARSE_Q_MULTIPLE"(#loc43)) +#loc498 = loc("SPARSE_KV_MULTIPLE"(#loc44)) +#loc499 = loc("off_hq2"(#loc45)) +#loc500 = loc("off_hq2"(#loc46)) +#loc501 = loc("off_hq2"(#loc47)) +#loc502 = loc("start_m2_block"(#loc48)) +#loc503 = loc("off_pid_mask"(#loc49)) +#loc504 = loc("stride_kv_num_blks_h"(#loc50)) +#loc505 = loc("stride_kv_idx_h"(#loc51)) +#loc506 = loc("stride_kv_idx_m"(#loc52)) +#loc507 = loc("sparse_idx_hq2"(#loc53)) +#loc508 = loc("sparse_hz_offset"(#loc54)) +#loc509 = loc("sparse_hz_offset"(#loc55)) +#loc510 = loc("sparse_kv_num_blks_offset"(#loc56)) +#loc511 = loc("sparse_kv_num_blks_offset"(#loc57)) +#loc512 = loc("sparse_kv_idx_offset"(#loc58)) +#loc513 = loc("sparse_kv_idx_offset"(#loc59)) +#loc514 = loc("sparse_kv_idx_offset"(#loc60)) +#loc515 = loc("q_adj2"(#loc61)) +#loc516 = loc("q_adj2"(#loc62)) +#loc517 = loc("q_adj2"(#loc63)) +#loc518 = loc("q_adj2"(#loc64)) +#loc519 = loc("do_adj2"(#loc65)) +#loc520 = loc("do_adj2"(#loc66)) +#loc521 = loc("do_adj2"(#loc67)) +#loc522 = loc("do_adj2"(#loc68)) +#loc523 = loc("dq_adj2"(#loc69)) +#loc524 = loc("dq_adj2"(#loc70)) +#loc525 = loc("dq_adj2"(#loc71)) +#loc526 = loc("dq_adj2"(#loc72)) +#loc527 = loc("off_chz2"(#loc73)) +#loc528 = loc("off_chz2"(#loc74)) +#loc529 = loc("off_chz2"(#loc75)) +#loc530 = loc("off_chz2"(#loc76)) +#loc531 = loc("Q2"(#loc77)) +#loc532 = loc("DO2"(#loc78)) +#loc533 = loc("DQ2"(#loc79)) +#loc534 = loc("LSE2"(#loc80)) +#loc535 = loc("DELTA2"(#loc81)) +#loc536 = loc("dq"(#loc82)) +#loc537 = loc("start_m2"(#loc83)) +#loc538 = loc("offs_m2"(#loc84)) +#loc539 = loc("offs_m2"(#loc85)) +#loc540 = loc("q"(#loc86)) +#loc541 = loc("do"(#loc87)) +#loc542 = loc("Di"(#loc88)) +#loc543 = loc("Di"(#loc89)) +#loc544 = loc("lse"(#loc90)) +#loc545 = loc("lse"(#loc91)) +#loc546 = loc("lse"(#loc92)) +#loc547 = loc("lse"(#loc93)) +#loc548 = loc("lse"(#loc94)) +#loc549 = loc("kv_indices"(#loc95)) 
+#loc550 = loc("kv_start"(#loc96)) +#loc551 = loc("kv_start"(#loc97)) +#loc552 = loc("sparse_kv_num_blocks"(#loc98)) +#loc553 = loc("sparse_kv_num_blocks"(#loc99)) +#loc554 = loc("offs_n2"(#loc100)) +#loc555 = loc("offs_n2"(#loc101)) +#loc556 = loc("dq"(#loc102)) +#loc557 = loc("kv_indices"(#loc103)) +#loc558 = loc("kv_start"(#loc104)) +#loc559 = loc("kv_start"(#loc105)) +#loc560 = loc("sparse_kv_num_blocks"(#loc106)) +#loc561 = loc("sparse_kv_num_blocks"(#loc107)) +#loc562 = loc("offs_n2"(#loc108)) +#loc563 = loc("offs_n2"(#loc109)) +#loc564 = loc("dq"(#loc110)) +#loc565 = loc("dq_ptrs"(#loc111)) +#loc566 = loc("dq_ptrs"(#loc112)) +#loc567 = loc("dq_ptrs"(#loc113)) +#loc568 = loc("dq_ptrs"(#loc114)) +#loc569 = loc("dq_ptrs"(#loc115)) +#loc570 = loc("dq_ptrs"(#loc116)) +#loc571 = loc("dq"(#loc117)) +#loc572 = loc("SPARSE_Q_MULTIPLE"(#loc119)) +#loc573 = loc("SPARSE_KV_MULTIPLE"(#loc120)) +#loc574 = loc("pid_mask"(#loc121)) +#loc575 = loc("stride_q_num_blks_h"(#loc122)) +#loc576 = loc("stride_q_idx_h"(#loc123)) +#loc577 = loc("stride_q_idx_n"(#loc124)) +#loc578 = loc("dv"(#loc125)) +#loc579 = loc("dk"(#loc126)) +#loc580 = loc("start_n1"(#loc127)) +#loc581 = loc("offs_n1"(#loc128)) +#loc582 = loc("offs_n1"(#loc129)) +#loc583 = loc("k"(#loc130)) +#loc584 = loc("v"(#loc131)) +#loc585 = loc("dv"(#loc132)) +#loc586 = loc("off_hq1"(#loc133)) +#loc587 = loc("off_hq1"(#loc134)) +#loc588 = loc("q_adj1"(#loc135)) +#loc589 = loc("q_adj1"(#loc136)) +#loc590 = loc("q_adj1"(#loc137)) +#loc591 = loc("q_adj1"(#loc138)) +#loc592 = loc("do_adj1"(#loc139)) +#loc593 = loc("do_adj1"(#loc140)) +#loc594 = loc("do_adj1"(#loc141)) +#loc595 = loc("do_adj1"(#loc142)) +#loc596 = loc("dq_adj1"(#loc143)) +#loc597 = loc("dq_adj1"(#loc144)) +#loc598 = loc("dq_adj1"(#loc145)) +#loc599 = loc("dq_adj1"(#loc146)) +#loc600 = loc("off_chz1"(#loc147)) +#loc601 = loc("off_chz1"(#loc148)) +#loc602 = loc("off_chz1"(#loc149)) +#loc603 = loc("off_chz1"(#loc150)) +#loc604 = loc("Q1"(#loc151)) +#loc605 = 
loc("DO1"(#loc152)) +#loc606 = loc("LSE1"(#loc153)) +#loc607 = loc("DELTA1"(#loc154)) +#loc608 = loc("sparse_idx_hq1"(#loc155)) +#loc609 = loc("sparse_hz_offset"(#loc156)) +#loc610 = loc("sparse_hz_offset"(#loc157)) +#loc611 = loc("sparse_q_num_blks_offset"(#loc158)) +#loc612 = loc("sparse_q_num_blks_offset"(#loc159)) +#loc613 = loc("sparse_q_idx_offset"(#loc160)) +#loc614 = loc("sparse_q_idx_offset"(#loc161)) +#loc615 = loc("sparse_q_idx_offset"(#loc162)) +#loc616 = loc("q_indices"(#loc163)) +#loc617 = loc("q_start"(#loc164)) +#loc618 = loc("q_start"(#loc165)) +#loc619 = loc("sparse_q_num_blocks"(#loc166)) +#loc620 = loc("sparse_q_num_blocks"(#loc167)) +#loc621 = loc("offs_m1"(#loc168)) +#loc622 = loc("offs_m1"(#loc169)) +#loc623 = loc("q_indices"(#loc171)) +#loc624 = loc("q_start"(#loc172)) +#loc625 = loc("q_start"(#loc173)) +#loc626 = loc("sparse_q_num_blocks"(#loc174)) +#loc627 = loc("sparse_q_num_blocks"(#loc175)) +#loc628 = loc("offs_m1"(#loc176)) +#loc629 = loc("offs_m1"(#loc177)) +#loc630 = loc("dv_ptrs"(#loc180)) +#loc631 = loc("dv_ptrs"(#loc181)) +#loc632 = loc("dv_ptrs"(#loc182)) +#loc633 = loc("dv_ptrs"(#loc183)) +#loc634 = loc("dv_ptrs"(#loc184)) +#loc635 = loc("dv_ptrs"(#loc185)) +#loc636 = loc("index_n"(#loc186)) +#loc637 = loc("index_k"(#loc187)) +#loc638 = loc("index_v"(#loc188)) +#loc639 = loc("dk"(#loc190)) +#loc640 = loc("mask"(#loc191)) +#loc641 = loc("xindex"(#loc192)) +#loc642 = loc("xindex"(#loc193)) +#loc643 = loc("xindex"(#loc194)) +#loc644 = loc("xindex"(#loc195)) +#loc645 = loc("xindex"(#loc196)) +#loc646 = loc("xindex"(#loc197)) +#loc654 = loc("ptr"(#loc212)) +#loc655 = loc("ptr"(#loc213)) +#loc656 = loc("ptr"(#loc214)) +#loc657 = loc("ptr"(#loc215)) +#loc658 = loc("ptr"(#loc216)) +#loc659 = loc("ptr"(#loc217)) +#loc695 = loc("Q_LEN"(#loc222)) +#loc696 = loc("KV_LEN"(#loc223)) +#loc697 = loc("offs_k"(#loc224)) +#loc698 = loc("offs_v"(#loc225)) +#loc699 = loc("kT_ptrs"(#loc226)) +#loc700 = loc("kT_ptrs"(#loc227)) +#loc701 = 
loc("kT_ptrs"(#loc228)) +#loc702 = loc("kT_ptrs"(#loc229)) +#loc703 = loc("kT_ptrs"(#loc230)) +#loc704 = loc("kT_ptrs"(#loc231)) +#loc705 = loc("vT_ptrs"(#loc232)) +#loc706 = loc("vT_ptrs"(#loc233)) +#loc707 = loc("vT_ptrs"(#loc234)) +#loc708 = loc("vT_ptrs"(#loc235)) +#loc709 = loc("vT_ptrs"(#loc236)) +#loc710 = loc("vT_ptrs"(#loc237)) +#loc711 = loc("hi"(#loc238)) +#loc712 = loc("hi"(#loc239)) +#loc713 = loc("hi"(#loc240)) +#loc714 = loc("hi"(#loc241)) +#loc715 = loc("dq"(#loc242)) +#loc716 = loc("dq"(#loc243)) +#loc717 = loc("offset"(#loc244)) +#loc718 = loc("kT_ptrs"(#loc245)) +#loc719 = loc("kT_ptrs"(#loc246)) +#loc720 = loc("vT_ptrs"(#loc247)) +#loc721 = loc("vT_ptrs"(#loc248)) +#loc722 = loc("offs_n2"(#loc249)) +#loc762 = loc("kT"(#loc254)) +#loc763 = loc("qk"(#loc255)) +#loc764 = loc("qk"(#loc256)) +#loc765 = loc("n"(#loc257)) +#loc766 = loc("n"(#loc258)) +#loc767 = loc("m"(#loc259)) +#loc768 = loc("m"(#loc260)) +#loc769 = loc("tmp1"(#loc261)) +#loc770 = loc("tmp4"(#loc262)) +#loc771 = loc("tmp5"(#loc263)) +#loc772 = loc("tmp7"(#loc264)) +#loc773 = loc("tmp7"(#loc265)) +#loc774 = loc("tmp8"(#loc266)) +#loc775 = loc("tmp9"(#loc267)) +#loc776 = loc("tmp10"(#loc268)) +#loc777 = loc("tmp11"(#loc269)) +#loc778 = loc("tmp12"(#loc270)) +#loc779 = loc("tmp13"(#loc271)) +#loc780 = loc("tmp14"(#loc272)) +#loc781 = loc("tmp15"(#loc273)) +#loc782 = loc("tmp16"(#loc274)) +#loc783 = loc("tmp17"(#loc275)) +#loc784 = loc("tmp18"(#loc276)) +#loc785 = loc("tmp19"(#loc277)) +#loc786 = loc("tmp20"(#loc278)) +#loc787 = loc("tmp21"(#loc279)) +#loc788 = loc("tmp22"(#loc280)) +#loc789 = loc("tmp23"(#loc281)) +#loc790 = loc("tmp24"(#loc282)) +#loc791 = loc("tmp25"(#loc283)) +#loc792 = loc("tmp26"(#loc284)) +#loc793 = loc("tmp27"(#loc285)) +#loc794 = loc("tmp28"(#loc286)) +#loc795 = loc("tmp29"(#loc287)) +#loc796 = loc("tmp30"(#loc288)) +#loc797 = loc("tmp31"(#loc289)) +#loc798 = loc("tmp32"(#loc290)) +#loc799 = loc("tmp33"(#loc291)) +#loc800 = loc("tmp34"(#loc292)) +#loc801 = 
loc("tmp35"(#loc293)) +#loc802 = loc("tmp36"(#loc294)) +#loc803 = loc("tmp37"(#loc295)) +#loc804 = loc("tmp38"(#loc296)) +#loc805 = loc("post_mod_scores"(#loc297)) +#loc806 = loc("post_mod_scores"(#loc298)) +#loc807 = loc("p"(#loc299)) +#loc808 = loc("p"(#loc300)) +#loc809 = loc("vT"(#loc301)) +#loc810 = loc("dp"(#loc302)) +#loc811 = loc("ds"(#loc303)) +#loc812 = loc("ds"(#loc304)) +#loc813 = loc("ds"(#loc305)) +#loc814 = loc("scatter_mask"(#loc306)) +#loc815 = loc("scatter_mask"(#loc307)) +#loc816 = loc("scatter_mask"(#loc308)) +#loc817 = loc("scatter_mask"(#loc309)) +#loc818 = loc("scatter_mask"(#loc310)) +#loc819 = loc("ds"(#loc311)) +#loc820 = loc("ds"(#loc312)) +#loc821 = loc("dq"(#loc313)) +#loc822 = loc("dq"(#loc314)) +#loc823 = loc("dq"(#loc315)) +#loc829 = loc("cur_block_idx"(#loc322)) +#loc830 = loc("cur_block"(#loc323)) +#loc831 = loc("cur_block"(#loc324)) +#loc832 = loc("next_block"(#loc325)) +#loc833 = loc("next_block"(#loc326)) +#loc834 = loc("next_block"(#loc327)) +#loc835 = loc("next_block"(#loc328)) +#loc836 = loc("next_block"(#loc329)) +#loc837 = loc("needs_jump"(#loc330)) +#loc838 = loc("needs_jump"(#loc331)) +#loc839 = loc("needs_jump"(#loc332)) +#loc840 = loc("jump_to_block"(#loc333)) +#loc841 = loc("jump_to_block"(#loc334)) +#loc842 = loc("jump_to_block"(#loc335)) +#loc843 = loc("offset"(#loc336)) +#loc844 = loc("offset"(#loc337)) +#loc845 = loc("offset"(#loc338)) +#loc846 = loc("offset"(#loc339)) +#loc883 = loc("Q_LEN"(#loc343)) +#loc884 = loc("KV_LEN"(#loc344)) +#loc885 = loc("offs_k"(#loc345)) +#loc886 = loc("offs_v"(#loc346)) +#loc887 = loc("qT_ptrs"(#loc347)) +#loc888 = loc("qT_ptrs"(#loc348)) +#loc889 = loc("qT_ptrs"(#loc349)) +#loc890 = loc("qT_ptrs"(#loc350)) +#loc891 = loc("qT_ptrs"(#loc351)) +#loc892 = loc("qT_ptrs"(#loc352)) +#loc893 = loc("do_ptrs"(#loc353)) +#loc894 = loc("do_ptrs"(#loc354)) +#loc895 = loc("do_ptrs"(#loc355)) +#loc896 = loc("do_ptrs"(#loc356)) +#loc897 = loc("do_ptrs"(#loc357)) +#loc898 = loc("do_ptrs"(#loc358)) 
+#loc899 = loc("hi"(#loc359)) +#loc900 = loc("hi"(#loc360)) +#loc901 = loc("hi"(#loc361)) +#loc902 = loc("hi"(#loc362)) +#loc903 = loc("dk"(#loc363)) +#loc904 = loc("offset"(#loc365)) +#loc905 = loc("qT_ptrs"(#loc366)) +#loc906 = loc("qT_ptrs"(#loc367)) +#loc907 = loc("do_ptrs"(#loc368)) +#loc908 = loc("do_ptrs"(#loc369)) +#loc909 = loc("offs_m1"(#loc370)) +#loc950 = loc("qT"(#loc375)) +#loc951 = loc("lse"(#loc376)) +#loc952 = loc("lse"(#loc377)) +#loc953 = loc("lse"(#loc378)) +#loc954 = loc("lse"(#loc379)) +#loc955 = loc("qkT"(#loc380)) +#loc956 = loc("qkT"(#loc381)) +#loc957 = loc("m"(#loc382)) +#loc958 = loc("m"(#loc383)) +#loc959 = loc("n"(#loc384)) +#loc960 = loc("n"(#loc385)) +#loc961 = loc("tmp41"(#loc386)) +#loc962 = loc("tmp44"(#loc387)) +#loc963 = loc("tmp45"(#loc388)) +#loc964 = loc("tmp47"(#loc389)) +#loc965 = loc("tmp47"(#loc390)) +#loc966 = loc("tmp48"(#loc391)) +#loc967 = loc("tmp49"(#loc392)) +#loc968 = loc("tmp50"(#loc393)) +#loc969 = loc("tmp51"(#loc394)) +#loc970 = loc("tmp52"(#loc395)) +#loc971 = loc("tmp53"(#loc396)) +#loc972 = loc("tmp54"(#loc397)) +#loc973 = loc("tmp55"(#loc398)) +#loc974 = loc("tmp56"(#loc399)) +#loc975 = loc("tmp57"(#loc400)) +#loc976 = loc("tmp58"(#loc401)) +#loc977 = loc("tmp59"(#loc402)) +#loc978 = loc("tmp60"(#loc403)) +#loc979 = loc("tmp61"(#loc404)) +#loc980 = loc("tmp62"(#loc405)) +#loc981 = loc("tmp63"(#loc406)) +#loc982 = loc("tmp64"(#loc407)) +#loc983 = loc("tmp65"(#loc408)) +#loc984 = loc("tmp66"(#loc409)) +#loc985 = loc("tmp67"(#loc410)) +#loc986 = loc("tmp68"(#loc411)) +#loc987 = loc("tmp69"(#loc412)) +#loc988 = loc("tmp70"(#loc413)) +#loc989 = loc("tmp71"(#loc414)) +#loc990 = loc("tmp72"(#loc415)) +#loc991 = loc("tmp73"(#loc416)) +#loc992 = loc("tmp74"(#loc417)) +#loc993 = loc("tmp75"(#loc418)) +#loc994 = loc("tmp76"(#loc419)) +#loc995 = loc("tmp77"(#loc420)) +#loc996 = loc("tmp78"(#loc421)) +#loc997 = loc("post_mod_scores"(#loc422)) +#loc998 = loc("post_mod_scores"(#loc423)) +#loc999 = loc("pT"(#loc424)) 
+#loc1000 = loc("pT"(#loc425)) +#loc1001 = loc("pT"(#loc426)) +#loc1002 = loc("do"(#loc427)) +#loc1003 = loc("dv"(#loc428)) +#loc1004 = loc("dv"(#loc429)) +#loc1005 = loc("dv"(#loc430)) +#loc1006 = loc("Di"(#loc431)) +#loc1007 = loc("Di"(#loc432)) +#loc1008 = loc("dpT"(#loc433)) +#loc1009 = loc("dpT"(#loc434)) +#loc1010 = loc("dsT"(#loc435)) +#loc1011 = loc("dsT"(#loc436)) +#loc1012 = loc("dsT"(#loc437)) +#loc1013 = loc("dsT"(#loc438)) +#loc1014 = loc("dk"(#loc439)) +#loc1015 = loc("dk"(#loc440)) +#loc1016 = loc("dk"(#loc441)) +#loc1017 = loc("dk"(#loc442)) +#loc1018 = loc("SPARSE_Q_MULTIPLE"(#loc497)) +#loc1019 = loc("SPARSE_KV_MULTIPLE"(#loc498)) +#loc1020 = loc("SPARSE_Q_MULTIPLE"(#loc572)) +#loc1021 = loc("SPARSE_KV_MULTIPLE"(#loc573)) +#loc1022 = loc("dk"(#loc585)) +#loc1023 = loc("offs_n2"(#loc715)) +#loc1024 = loc("dv"(#loc903)) +#loc1025 = loc("kT_ptrs"(#loc1023)) +#loc1026 = loc("offs_m1"(#loc1024)) +#loc1027 = loc("vT_ptrs"(#loc1025)) +#loc1028 = loc("qT_ptrs"(#loc1026)) +#loc1029 = loc("do_ptrs"(#loc1028)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..af9c4caa9ecf419cfb4a3f1bb183ccb7bb98749f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttgir @@ -0,0 +1,1749 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":18:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, 
versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#loc289 = loc("arg_Q"(#loc)) +#loc290 = loc("arg_K"(#loc)) +#loc291 = loc("arg_V"(#loc)) +#loc292 = loc("arg_LSE"(#loc)) +#loc293 = loc("arg_DELTA"(#loc)) +#loc294 = loc("arg_DO"(#loc)) +#loc295 = loc("arg_DQ"(#loc)) +#loc296 = loc("arg_DV"(#loc)) +#loc297 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc298 = loc("arg_KV_IDX"(#loc)) +#loc299 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc300 = loc("arg_Q_IDX"(#loc)) +#loc301 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc302 = loc("arg_FULL_KV_IDX"(#loc)) +#loc303 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc304 = loc("arg_FULL_Q_IDX"(#loc)) +#loc305 = loc("in_ptr16"(#loc)) +#loc306 = loc("out_ptr0"(#loc)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_tem_fused_zeros_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), 
%arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %in_ptr16: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr16"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<128x1xi32, #mma> loc(#loc1) + %cst_0 = arith.constant dense<2048> : tensor<128x1xi32, #mma> loc(#loc1) + %cst_1 = arith.constant dense<128> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<4096> : tensor<128x1xi32, #blocked> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c8388608_i32 = arith.constant 8388608 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c2097152_i32 = arith.constant 2097152 : i32 loc(#loc1) + %c262144_i32 = arith.constant 262144 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_3 = arith.constant dense<0.0883883461> : tensor<128x128xf32, #mma1> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> loc(#loc1) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> loc(#loc1) + %cst_6 = arith.constant dense<0.0883883461> : 
tensor<128x64xf32, #mma> loc(#loc1) + %cst_7 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> loc(#loc1) + %cst_8 = arith.constant dense<1.44269502> : tensor<128x64xf32, #mma> loc(#loc1) + %true = arith.constant true loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %cst_9 = arith.constant dense<8192> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_10 = arith.constant dense<262144> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_11 = arith.constant dense<8192> : tensor<64x128xi32, #blocked> loc(#loc1) + %cst_12 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc1) + %cst_13 = arith.constant dense<2048> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_14 = arith.constant dense<128> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_15 = arith.constant dense<4096> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_16 = arith.constant dense<128> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_17 = arith.constant dense<2048> : tensor<1x64xi32, #mma> loc(#loc1) + %cst_18 = arith.constant dense<2048> : tensor<128x64xi32, #mma> loc(#loc1) + %cst_19 = arith.constant dense<0xFF800000> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc1) + %cst_20 = arith.constant dense<0.000000e+00> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc1) + %cst_21 = arith.constant dense<0> : tensor<128x64xi32, #mma> loc(#loc1) + %cst_22 = arith.constant dense<0> : tensor<1x64xi32, #mma> loc(#loc1) + %cst_23 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc1) + %cst_24 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc1) + %pid = tt.get_program_id x : i32 loc(#loc307) + %off_zq = tt.get_program_id y : i32 loc(#loc308) + %off_hkv = tt.get_program_id z : i32 loc(#loc309) + %off_zkv = arith.remsi %off_zq, %c2_i32 : i32 loc(#loc310) + %k_adj = arith.muli %off_hkv, 
%c262144_i32 : i32 loc(#loc311) + %k_adj_25 = arith.muli %off_zkv, %c2097152_i32 : i32 loc(#loc312) + %k_adj_26 = arith.addi %k_adj, %k_adj_25 : i32 loc(#loc313) + %k_adj_27 = arith.extsi %k_adj_26 : i32 to i64 loc(#loc314) + %dv_adj = arith.muli %off_zq, %c2097152_i32 : i32 loc(#loc315) + %dv_adj_28 = arith.addi %k_adj, %dv_adj : i32 loc(#loc316) + %dv_adj_29 = arith.extsi %dv_adj_28 : i32 to i64 loc(#loc317) + %K = tt.addptr %arg_K, %k_adj_27 : !tt.ptr, i64 loc(#loc318) + %V = tt.addptr %arg_V, %k_adj_27 : !tt.ptr, i64 loc(#loc319) + %DV = tt.addptr %arg_DV, %dv_adj_29 : !tt.ptr, i64 loc(#loc320) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc321) + %offs_k_30 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc321) + %0 = arith.cmpi sge, %pid, %c16_i32 : i32 loc(#loc17) + scf.if %0 { + %off_pid = arith.subi %pid, %c16_i32 : i32 loc(#loc322) + %off_hq2 = arith.divsi %off_pid, %c16_i32 : i32 loc(#loc323) + %off_hq2_31 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc324) + %off_hq2_32 = arith.addi %off_hq2, %off_hq2_31 : i32 loc(#loc325) + %start_m2_block = arith.remsi %off_pid, %c16_i32 : i32 loc(#loc326) + %sparse_kv_num_blks_offset = arith.muli %off_zkv, %c16_i32 : i32 loc(#loc327) + %sparse_kv_num_blks_offset_33 = arith.addi %sparse_kv_num_blks_offset, %start_m2_block : i32 loc(#loc328) + %sparse_kv_idx_offset = arith.muli %off_zkv, %c256_i32 : i32 loc(#loc329) + %sparse_kv_idx_offset_34 = arith.muli %start_m2_block, %c16_i32 : i32 loc(#loc330) + %sparse_kv_idx_offset_35 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_34 : i32 loc(#loc331) + %q_adj2 = arith.muli %off_hq2_32, %c128_i32 : i32 loc(#loc332) + %q_adj2_36 = arith.muli %off_zq, %c8388608_i32 : i32 loc(#loc333) + %q_adj2_37 = arith.addi %q_adj2, %q_adj2_36 : i32 loc(#loc334) + %q_adj2_38 = arith.extsi %q_adj2_37 : i32 to i64 loc(#loc335) + %do_adj2 = 
arith.muli %off_hq2_32, %c262144_i32 : i32 loc(#loc336) + %do_adj2_39 = arith.addi %do_adj2, %q_adj2_36 : i32 loc(#loc337) + %do_adj2_40 = arith.extsi %do_adj2_39 : i32 to i64 loc(#loc338) + %off_chz2 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc339) + %off_chz2_41 = arith.addi %off_chz2, %off_hq2_32 : i32 loc(#loc340) + %off_chz2_42 = arith.muli %off_chz2_41, %c2048_i32 : i32 loc(#loc341) + %off_chz2_43 = arith.extsi %off_chz2_42 : i32 to i64 loc(#loc342) + %Q2 = tt.addptr %arg_Q, %q_adj2_38 : !tt.ptr, i64 loc(#loc343) + %DO2 = tt.addptr %arg_DO, %do_adj2_40 : !tt.ptr, i64 loc(#loc344) + %DQ2 = tt.addptr %arg_DQ, %q_adj2_38 : !tt.ptr, i64 loc(#loc345) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_43 : !tt.ptr, i64 loc(#loc346) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_43 : !tt.ptr, i64 loc(#loc347) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc348) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc349) + %offs_m2_44 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc349) + %offs_m2_45 = arith.addi %offs_m2, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc349) + %offs_m2_46 = arith.addi %offs_m2_44, %offs_k_30 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc349) + %ptr = tt.expand_dims %offs_m2_45 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc584) + %ptr_47 = tt.expand_dims %offs_m2_46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> loc(#loc584) + %ptr_48 = arith.muli %ptr, %cst_2 : tensor<128x1xi32, #blocked> loc(#loc585) + %ptr_49 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc586) + %ptr_50 = tt.addptr %ptr_49, %ptr_48 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc586) + %ptr_51 = tt.make_range {end = 128 : i32, start = 0 : i32} : 
tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc587) + %ptr_52 = tt.expand_dims %ptr_51 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc587) + %ptr_53 = tt.broadcast %ptr_50 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc588) + %ptr_54 = tt.broadcast %ptr_52 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc588) + %ptr_55 = tt.addptr %ptr_53, %ptr_54 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc588) + %q = tt.load %ptr_55 : tensor<128x128x!tt.ptr, #blocked> loc(#loc589) + %q_56 = ttg.local_alloc %q : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc589) + %ptr_57 = arith.muli %ptr, %cst_1 : tensor<128x1xi32, #blocked> loc(#loc590) + %ptr_58 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc591) + %ptr_59 = tt.addptr %ptr_58, %ptr_57 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc591) + %ptr_60 = tt.broadcast %ptr_59 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc592) + %ptr_61 = tt.addptr %ptr_60, %ptr_54 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc592) + %do = tt.load %ptr_61 : tensor<128x128x!tt.ptr, #blocked> loc(#loc593) + %do_62 = ttg.local_alloc %do : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc593) + %Di = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc357) + %Di_63 = tt.addptr %Di, %offs_m2_46 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc357) + %Di_64 = tt.load %Di_63 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc358) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc359) + %lse_65 = tt.addptr %lse, 
%offs_m2_46 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc359) + %lse_66 = tt.load %lse_65 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc360) + %lse_67 = arith.cmpf oeq, %lse_66, %cst_24 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc361) + %lse_68 = arith.select %lse_67, %cst_23, %lse_66 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc362) + %lse_69 = tt.expand_dims %lse_68 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> loc(#loc363) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_35 : !tt.ptr, i32 loc(#loc364) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc365) + %kv_start_70 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc366) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_33 : !tt.ptr, i32 loc(#loc367) + %sparse_kv_num_blocks_71 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc368) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc369) + %offs_n2_72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc369) + %offs_n2_73 = tt.splat %kv_start_70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc370) + %offs_n2_74 = tt.splat %kv_start_70 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc370) + %offs_n2_75 = arith.addi %offs_n2_73, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc370) + %offs_n2_76 = arith.addi %offs_n2_74, %offs_n2_72 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc370) + %kT_ptrs = tt.expand_dims %offs_n2_75 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc594) + 
%kT_ptrs_77 = arith.muli %kT_ptrs, %cst_16 : tensor<1x64xi32, #blocked1> loc(#loc595) + %kT_ptrs_78 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc596) + %kT_ptrs_79 = tt.addptr %kT_ptrs_78, %kT_ptrs_77 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc596) + %kT_ptrs_80 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc597) + %kT_ptrs_81 = tt.expand_dims %kT_ptrs_80 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc597) + %kT_ptrs_82 = tt.broadcast %kT_ptrs_79 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc598) + %kT_ptrs_83 = tt.broadcast %kT_ptrs_81 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc598) + %kT_ptrs_84 = tt.addptr %kT_ptrs_82, %kT_ptrs_83 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc598) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc599) + %vT_ptrs_85 = tt.addptr %vT_ptrs, %kT_ptrs_77 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc599) + %vT_ptrs_86 = tt.broadcast %vT_ptrs_85 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc600) + %vT_ptrs_87 = tt.addptr %vT_ptrs_86, %kT_ptrs_83 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc600) + %hi = arith.muli %sparse_kv_num_blocks_71, %c2_i32 : i32 loc(#loc601) + %hi_88 = arith.minsi %hi, %c32_i32 : i32 loc(#loc602) + %tmp4 = tt.broadcast %ptr_47 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> loc(#loc747) + %tmp7 = tt.addptr %in_ptr16, %off_zq : !tt.ptr, i32 loc(#loc748) + %vT_ptrs_89 = arith.cmpi sgt, %hi_88, %c0_i32 : i32 loc(#loc886) + %tmp7_90 = tt.load %tmp7, %vT_ptrs_89 : !tt.ptr loc(#loc750) + %tmp8 = tt.splat %tmp7_90 : i64 -> tensor<1x64xi64, #mma> loc(#loc751) + %tmp9 = arith.extsi %ptr_47 : tensor<128x1xi32, #mma> to 
tensor<128x1xi64, #mma> loc(#loc752) + %tmp10 = tt.splat %tmp7_90 : i64 -> tensor<128x1xi64, #mma> loc(#loc753) + %tmp10_91 = arith.cmpi slt, %tmp9, %tmp10 : tensor<128x1xi64, #mma> loc(#loc753) + %tmp11 = tt.broadcast %tmp10_91 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc754) + %p = tt.broadcast %lse_69 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc755) + %ds = tt.expand_dims %Di_64 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> loc(#loc756) + %ds_92 = tt.broadcast %ds : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc757) + %kT = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc881) + %vT = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc882) + %kT_93 = ttg.memdesc_index %kT[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %vT_ptrs_94 = tt.splat %vT_ptrs_89 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc886) + %kT_95 = ttg.async_copy_global_to_local %kT_ptrs_84, %kT_93 mask %vT_ptrs_94 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %kT_96 = ttg.async_commit_group tokens %kT_95 loc(#loc881) + %vT_97 = ttg.memdesc_index %vT[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_98 = ttg.async_copy_global_to_local %vT_ptrs_87, %vT_97 mask %vT_ptrs_94 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_99 = ttg.async_commit_group tokens %vT_98 loc(#loc882) + %vT_ptrs_100 = arith.cmpi sgt, %hi_88, %c1_i32 : i32 loc(#loc886) + %kT_ptrs_101 = tt.addptr %kT_ptrs_84, %cst_9 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc605) + %vT_ptrs_102 = tt.addptr %vT_ptrs_87, %cst_9 : 
tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc606) + %kT_103 = ttg.memdesc_index %kT[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %vT_ptrs_104 = tt.splat %vT_ptrs_100 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc886) + %kT_105 = ttg.async_copy_global_to_local %kT_ptrs_101, %kT_103 mask %vT_ptrs_104 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %kT_106 = ttg.async_commit_group tokens %kT_105 loc(#loc881) + %vT_107 = ttg.memdesc_index %vT[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_108 = ttg.async_copy_global_to_local %vT_ptrs_102, %vT_107 mask %vT_ptrs_104 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_109 = ttg.async_commit_group tokens %vT_108 loc(#loc882) + ttng.fence_async_shared {bCluster = false} loc(#loc760) + %vT_ptrs_110:11 = scf.for %vT_ptrs_156 = %c0_i32 to %hi_88 step %c1_i32 iter_args(%arg19 = %cst_4, %kT_ptrs_157 = %kT_ptrs_101, %vT_ptrs_158 = %vT_ptrs_102, %offs_n2_159 = %offs_n2_76, %arg23 = %c1_i32, %arg24 = %c-1_i32, %kT_160 = %kT_96, %kT_161 = %kT_106, %vT_162 = %vT_99, %vT_163 = %vT_109, %arg29 = %c64_i32) -> (tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32) : i32 { + %vT_ptrs_164 = arith.subi %hi_88, %c2_i32 : i32 loc(#loc886) + %vT_ptrs_165 = arith.cmpi slt, %vT_ptrs_156, %vT_ptrs_164 : i32 loc(#loc886) + %vT_ptrs_166 = arith.subi %hi_88, %c1_i32 : i32 loc(#loc886) + %vT_ptrs_167 = arith.cmpi slt, %vT_ptrs_156, %vT_ptrs_166 : i32 loc(#loc886) + %vT_ptrs_168 = arith.addi %arg24, %c1_i32 : i32 loc(#loc886) + 
%vT_ptrs_169 = arith.cmpi sge, %vT_ptrs_168, %c3_i32 : i32 loc(#loc886) + %vT_ptrs_170 = arith.select %vT_ptrs_169, %c0_i32, %vT_ptrs_168 : i32 loc(#loc886) + %kT_171 = ttg.async_wait %kT_160, %vT_162 {num = 2 : i32} loc(#loc881) + %kT_172 = ttg.memdesc_index %kT[%vT_ptrs_170] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %dq_173 = ttg.memdesc_trans %kT_172 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc761) + %qk = ttng.warp_group_dot %q_56, %kT_172, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc760) + %qk_174:3 = ttng.warp_group_dot_wait %qk, %q_56, %kT_172 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc760) + %qk_175 = arith.mulf %qk_174#0, %cst_6 : tensor<128x64xf32, #mma> loc(#loc762) + %n = tt.expand_dims %offs_n2_159 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc763) + %tmp4_176 = tt.broadcast %n : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> loc(#loc747) + %tmp4_177 = arith.cmpi sge, %tmp4, %tmp4_176 : tensor<128x64xi32, #mma> loc(#loc747) + %tmp5 = arith.extsi %n : tensor<1x64xi32, #mma> to tensor<1x64xi64, #mma> loc(#loc764) + %tmp8_178 = arith.cmpi slt, %tmp5, %tmp8 : tensor<1x64xi64, #mma> loc(#loc751) + %tmp11_179 = tt.broadcast %tmp8_178 : tensor<1x64xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc754) + %tmp11_180 = arith.andi %tmp11_179, %tmp11 : tensor<128x64xi1, #mma> loc(#loc754) + %tmp12 = arith.andi %tmp4_177, %tmp11_180 : tensor<128x64xi1, #mma> loc(#loc765) + %tmp15 = arith.cmpi sge, %n, %cst_17 : tensor<1x64xi32, #mma> 
loc(#loc766) + %tmp16 = arith.remsi %n, %cst_17 : tensor<1x64xi32, #mma> loc(#loc767) + %tmp18 = arith.cmpi ne, %tmp16, %cst_22 : tensor<1x64xi32, #mma> loc(#loc768) + %tmp19 = arith.cmpi slt, %tmp16, %cst_22 : tensor<1x64xi32, #mma> loc(#loc769) + %tmp22 = arith.andi %tmp18, %tmp19 : tensor<1x64xi1, #mma> loc(#loc770) + %tmp23 = arith.addi %tmp16, %cst_17 : tensor<1x64xi32, #mma> loc(#loc771) + %tmp24 = arith.select %tmp22, %tmp23, %tmp16 : tensor<1x64xi1, #mma>, tensor<1x64xi32, #mma> loc(#loc772) + %tmp25 = arith.extsi %tmp24 : tensor<1x64xi32, #mma> to tensor<1x64xi64, #mma> loc(#loc773) + %tmp26 = arith.cmpi slt, %tmp25, %tmp8 : tensor<1x64xi64, #mma> loc(#loc774) + %tmp27 = arith.andi %tmp15, %tmp26 : tensor<1x64xi1, #mma> loc(#loc775) + %tmp28 = arith.subi %tmp4_176, %tmp4 : tensor<128x64xi32, #mma> loc(#loc776) + %tmp29 = arith.remsi %tmp28, %cst_18 : tensor<128x64xi32, #mma> loc(#loc777) + %tmp30 = arith.cmpi ne, %tmp29, %cst_21 : tensor<128x64xi32, #mma> loc(#loc778) + %tmp31 = arith.cmpi slt, %tmp29, %cst_21 : tensor<128x64xi32, #mma> loc(#loc779) + %tmp33 = arith.andi %tmp30, %tmp31 : tensor<128x64xi1, #mma> loc(#loc780) + %tmp34 = arith.addi %tmp29, %cst_18 : tensor<128x64xi32, #mma> loc(#loc781) + %tmp35 = arith.select %tmp33, %tmp34, %tmp29 : tensor<128x64xi1, #mma>, tensor<128x64xi32, #mma> loc(#loc782) + %tmp36 = arith.cmpi eq, %tmp35, %cst_21 : tensor<128x64xi32, #mma> loc(#loc783) + %tmp37 = tt.broadcast %tmp27 : tensor<1x64xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc784) + %tmp37_181 = arith.andi %tmp37, %tmp36 : tensor<128x64xi1, #mma> loc(#loc784) + %tmp38 = arith.ori %tmp12, %tmp37_181 : tensor<128x64xi1, #mma> loc(#loc785) + %post_mod_scores = arith.select %tmp38, %qk_175, %cst_7 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> loc(#loc786) + %post_mod_scores_182 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32, #mma> loc(#loc787) + %p_183 = arith.subf %post_mod_scores_182, %p : tensor<128x64xf32, #mma> loc(#loc755) + %p_184 = 
math.exp2 %p_183 : tensor<128x64xf32, #mma> loc(#loc788) + %vT_185 = ttg.memdesc_index %vT[%vT_ptrs_170] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %dp = ttng.warp_group_dot %do_62, %vT_185, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc789) + %dp_186:3 = ttng.warp_group_dot_wait %dp, %do_62, %vT_185 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc789) + %ds_187 = arith.subf %dp_186#0, %ds_92 : tensor<128x64xf32, #mma> loc(#loc757) + %ds_188 = arith.mulf %p_184, %ds_187 : tensor<128x64xf32, #mma> loc(#loc790) + %ds_189 = arith.select %tmp38, %ds_188, %cst_5 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> loc(#loc791) + %ds_190 = arith.truncf %ds_189 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc792) + %ds_191 = ttg.convert_layout %ds_190 : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc792) + %dq_192 = ttng.warp_group_dot %ds_191, %dq_173, %arg19 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc793) + %offs_n2_193 = tt.splat %arg29 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc607) + %offs_n2_194 = arith.addi %offs_n2_159, %offs_n2_193 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc607) + %vT_ptrs_195 = arith.addi %vT_ptrs_156, %c1_i32 : i32 loc(#loc886) + %cur_block_idx = arith.divsi %vT_ptrs_195, %c2_i32 : i32 loc(#loc794) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc795) + %cur_block_196 = 
tt.load %cur_block, %vT_ptrs_167 evictionPolicy = evict_last : !tt.ptr loc(#loc796) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc797) + %next_block_197 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_71 : i32 loc(#loc798) + %next_block_198 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc799) + %vT_ptrs_199 = arith.andi %vT_ptrs_167, %next_block_197 : i1 loc(#loc886) + %next_block_200 = tt.load %next_block_198, %vT_ptrs_199 evictionPolicy = evict_last : !tt.ptr loc(#loc800) + %needs_jump = arith.addi %vT_ptrs_156, %c2_i32 : i32 loc(#loc801) + %needs_jump_201 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc802) + %needs_jump_202 = arith.cmpi eq, %needs_jump_201, %c0_i32 : i32 loc(#loc803) + %jump_to_block = arith.subi %next_block_200, %cur_block_196 : i32 loc(#loc804) + %jump_to_block_203 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc805) + %jump_to_block_204 = arith.subi %jump_to_block_203, %c64_i32 : i32 loc(#loc806) + %offset = arith.extui %needs_jump_202 : i1 to i32 loc(#loc807) + %offset_205 = arith.muli %jump_to_block_204, %offset : i32 loc(#loc807) + %offset_206 = arith.subi %c1_i32, %offset : i32 loc(#loc808) + %offset_207 = arith.muli %offset_206, %c64_i32 : i32 loc(#loc809) + %offset_208 = arith.addi %offset_205, %offset_207 : i32 loc(#loc810) + %kT_ptrs_209 = arith.muli %offset_208, %c128_i32 : i32 loc(#loc609) + %kT_ptrs_210 = tt.splat %kT_ptrs_209 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc605) + %kT_ptrs_211 = tt.addptr %kT_ptrs_157, %kT_ptrs_210 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc605) + %vT_ptrs_212 = tt.addptr %vT_ptrs_158, %kT_ptrs_210 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc606) + %vT_ptrs_213 = arith.addi %arg23, %c1_i32 : i32 loc(#loc886) + %vT_ptrs_214 = arith.cmpi sge, %vT_ptrs_213, %c3_i32 : i32 loc(#loc886) + %vT_ptrs_215 = arith.select %vT_ptrs_214, %c0_i32, %vT_ptrs_213 : i32 loc(#loc886) + %kT_216 = ttg.memdesc_index 
%kT[%vT_ptrs_215] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %vT_ptrs_217 = tt.splat %vT_ptrs_165 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc886) + %kT_218 = ttg.async_copy_global_to_local %kT_ptrs_211, %kT_216 mask %vT_ptrs_217 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc881) + %kT_219 = ttg.async_commit_group tokens %kT_218 loc(#loc881) + %vT_220 = ttg.memdesc_index %vT[%vT_ptrs_215] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_221 = ttg.async_copy_global_to_local %vT_ptrs_212, %vT_220 mask %vT_ptrs_217 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc882) + %vT_222 = ttg.async_commit_group tokens %vT_221 loc(#loc882) + scf.yield %dq_192, %kT_ptrs_211, %vT_ptrs_212, %offs_n2_194, %vT_ptrs_215, %vT_ptrs_170, %kT_161, %kT_219, %vT_163, %vT_222, %offset_208 : tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32 loc(#loc886) + } loc(#loc886) + %vT_ptrs_111 = ttng.warp_group_dot_wait %vT_ptrs_110#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma1> loc(#loc886) + %vT_ptrs_112 = ttg.async_wait {num = 0 : i32} loc(#loc886) + ttg.local_dealloc %vT : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc886) + ttg.local_dealloc %kT : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc886) + %kv_indices_113 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_35 : !tt.ptr, i32 loc(#loc451) + %kv_start_114 = tt.load %kv_indices_113 : !tt.ptr loc(#loc452) + %kv_start_115 = arith.muli %kv_start_114, %c128_i32 : i32 loc(#loc453) + %sparse_kv_num_blocks_116 = tt.addptr %arg_FULL_KV_NUM_BLKS, 
%sparse_kv_num_blks_offset_33 : !tt.ptr, i32 loc(#loc454) + %sparse_kv_num_blocks_117 = tt.load %sparse_kv_num_blocks_116 : !tt.ptr loc(#loc455) + %offs_n2_118 = tt.splat %kv_start_115 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc456) + %offs_n2_119 = arith.addi %offs_n2_118, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc456) + %kT_ptrs_120 = tt.expand_dims %offs_n2_119 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc610) + %kT_ptrs_121 = arith.muli %kT_ptrs_120, %cst_16 : tensor<1x64xi32, #blocked1> loc(#loc611) + %kT_ptrs_122 = tt.addptr %kT_ptrs_78, %kT_ptrs_121 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc612) + %kT_ptrs_123 = tt.broadcast %kT_ptrs_122 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc613) + %kT_ptrs_124 = tt.addptr %kT_ptrs_123, %kT_ptrs_83 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc613) + %vT_ptrs_125 = tt.addptr %vT_ptrs, %kT_ptrs_121 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc614) + %vT_ptrs_126 = tt.broadcast %vT_ptrs_125 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc615) + %vT_ptrs_127 = tt.addptr %vT_ptrs_126, %kT_ptrs_83 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc615) + %hi_128 = arith.muli %sparse_kv_num_blocks_117, %c2_i32 : i32 loc(#loc616) + %hi_129 = arith.minsi %hi_128, %c32_i32 : i32 loc(#loc617) + %kT_130 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc883) + %vT_131 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc884) + %vT_ptrs_132 = arith.cmpi sgt, %hi_129, %c0_i32 : i32 loc(#loc887) + %kT_133 = ttg.memdesc_index %kT_130[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 
3x128x64> loc(#loc883) + %vT_ptrs_134 = tt.splat %vT_ptrs_132 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc887) + %kT_135 = ttg.async_copy_global_to_local %kT_ptrs_124, %kT_133 mask %vT_ptrs_134 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %kT_136 = ttg.async_commit_group tokens %kT_135 loc(#loc883) + %vT_137 = ttg.memdesc_index %vT_131[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %vT_138 = ttg.async_copy_global_to_local %vT_ptrs_127, %vT_137 mask %vT_ptrs_134 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %vT_139 = ttg.async_commit_group tokens %vT_138 loc(#loc884) + %vT_ptrs_140 = arith.cmpi sgt, %hi_129, %c1_i32 : i32 loc(#loc887) + %kT_ptrs_141 = tt.addptr %kT_ptrs_124, %cst_9 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc619) + %vT_ptrs_142 = tt.addptr %vT_ptrs_127, %cst_9 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc620) + %kT_143 = ttg.memdesc_index %kT_130[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %vT_ptrs_144 = tt.splat %vT_ptrs_140 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc887) + %kT_145 = ttg.async_copy_global_to_local %kT_ptrs_141, %kT_143 mask %vT_ptrs_144 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %kT_146 = ttg.async_commit_group tokens %kT_145 loc(#loc883) + %vT_147 = ttg.memdesc_index %vT_131[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %vT_148 = ttg.async_copy_global_to_local %vT_ptrs_142, %vT_147 mask %vT_ptrs_144 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + 
%vT_149 = ttg.async_commit_group tokens %vT_148 loc(#loc884) + ttng.fence_async_shared {bCluster = false} loc(#loc813) + %vT_ptrs_150:9 = scf.for %vT_ptrs_156 = %c0_i32 to %hi_129 step %c1_i32 iter_args(%vT_ptrs_157 = %vT_ptrs_111, %kT_ptrs_158 = %kT_ptrs_141, %vT_ptrs_159 = %vT_ptrs_142, %arg22 = %c1_i32, %arg23 = %c-1_i32, %kT_160 = %kT_136, %kT_161 = %kT_146, %vT_162 = %vT_139, %vT_163 = %vT_149) -> (tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { + %vT_ptrs_164 = arith.subi %hi_129, %c2_i32 : i32 loc(#loc887) + %vT_ptrs_165 = arith.cmpi slt, %vT_ptrs_156, %vT_ptrs_164 : i32 loc(#loc887) + %vT_ptrs_166 = arith.subi %hi_129, %c1_i32 : i32 loc(#loc887) + %vT_ptrs_167 = arith.cmpi slt, %vT_ptrs_156, %vT_ptrs_166 : i32 loc(#loc887) + %vT_ptrs_168 = arith.addi %arg23, %c1_i32 : i32 loc(#loc887) + %vT_ptrs_169 = arith.cmpi sge, %vT_ptrs_168, %c3_i32 : i32 loc(#loc887) + %vT_ptrs_170 = arith.select %vT_ptrs_169, %c0_i32, %vT_ptrs_168 : i32 loc(#loc887) + %kT_171 = ttg.async_wait %kT_160, %vT_162 {num = 2 : i32} loc(#loc883) + %kT_172 = ttg.memdesc_index %kT_130[%vT_ptrs_170] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %dq_173 = ttg.memdesc_trans %kT_172 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc814) + %qk = ttng.warp_group_dot %q_56, %kT_172, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc813) + %qk_174:3 = ttng.warp_group_dot_wait %qk, %q_56, %kT_172 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, 
mutable, 3x128x64> loc(#loc813) + %qk_175 = arith.mulf %qk_174#0, %cst_6 : tensor<128x64xf32, #mma> loc(#loc815) + %post_mod_scores = arith.mulf %qk_175, %cst_8 : tensor<128x64xf32, #mma> loc(#loc816) + %p_176 = arith.subf %post_mod_scores, %p : tensor<128x64xf32, #mma> loc(#loc817) + %p_177 = math.exp2 %p_176 : tensor<128x64xf32, #mma> loc(#loc818) + %vT_178 = ttg.memdesc_index %vT_131[%vT_ptrs_170] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %dp = ttng.warp_group_dot %do_62, %vT_178, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc819) + %dp_179:3 = ttng.warp_group_dot_wait %dp, %do_62, %vT_178 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc819) + %ds_180 = arith.subf %dp_179#0, %ds_92 : tensor<128x64xf32, #mma> loc(#loc820) + %ds_181 = arith.mulf %p_177, %ds_180 : tensor<128x64xf32, #mma> loc(#loc821) + %ds_182 = arith.truncf %ds_181 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc822) + %ds_183 = ttg.convert_layout %ds_182 : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc822) + %dq_184 = ttng.warp_group_dot %ds_183, %dq_173, %vT_ptrs_157 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc823) + %vT_ptrs_185 = arith.addi %vT_ptrs_156, %c1_i32 : i32 loc(#loc887) + %cur_block_idx = arith.divsi %vT_ptrs_185, %c2_i32 : i32 loc(#loc824) + %cur_block = tt.addptr %kv_indices_113, %cur_block_idx : !tt.ptr, i32 loc(#loc825) + %cur_block_186 = tt.load %cur_block, %vT_ptrs_167 
evictionPolicy = evict_last : !tt.ptr loc(#loc826) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc827) + %next_block_187 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_117 : i32 loc(#loc828) + %next_block_188 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc829) + %vT_ptrs_189 = arith.andi %vT_ptrs_167, %next_block_187 : i1 loc(#loc887) + %next_block_190 = tt.load %next_block_188, %vT_ptrs_189 evictionPolicy = evict_last : !tt.ptr loc(#loc830) + %needs_jump = arith.addi %vT_ptrs_156, %c2_i32 : i32 loc(#loc831) + %needs_jump_191 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc832) + %needs_jump_192 = arith.cmpi eq, %needs_jump_191, %c0_i32 : i32 loc(#loc833) + %jump_to_block = arith.subi %next_block_190, %cur_block_186 : i32 loc(#loc834) + %jump_to_block_193 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc835) + %jump_to_block_194 = arith.subi %jump_to_block_193, %c64_i32 : i32 loc(#loc836) + %offset = arith.extui %needs_jump_192 : i1 to i32 loc(#loc837) + %offset_195 = arith.muli %jump_to_block_194, %offset : i32 loc(#loc837) + %offset_196 = arith.subi %c1_i32, %offset : i32 loc(#loc838) + %offset_197 = arith.muli %offset_196, %c64_i32 : i32 loc(#loc839) + %offset_198 = arith.addi %offset_195, %offset_197 : i32 loc(#loc840) + %kT_ptrs_199 = arith.muli %offset_198, %c128_i32 : i32 loc(#loc622) + %kT_ptrs_200 = tt.splat %kT_ptrs_199 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc619) + %kT_ptrs_201 = tt.addptr %kT_ptrs_158, %kT_ptrs_200 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc619) + %vT_ptrs_202 = tt.addptr %vT_ptrs_159, %kT_ptrs_200 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc620) + %vT_ptrs_203 = arith.addi %arg22, %c1_i32 : i32 loc(#loc887) + %vT_ptrs_204 = arith.cmpi sge, %vT_ptrs_203, %c3_i32 : i32 loc(#loc887) + %vT_ptrs_205 = arith.select %vT_ptrs_204, %c0_i32, %vT_ptrs_203 : i32 loc(#loc887) + %kT_206 = ttg.memdesc_index %kT_130[%vT_ptrs_205] : 
!ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %vT_ptrs_207 = tt.splat %vT_ptrs_165 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc887) + %kT_208 = ttg.async_copy_global_to_local %kT_ptrs_201, %kT_206 mask %vT_ptrs_207 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc883) + %kT_209 = ttg.async_commit_group tokens %kT_208 loc(#loc883) + %vT_210 = ttg.memdesc_index %vT_131[%vT_ptrs_205] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %vT_211 = ttg.async_copy_global_to_local %vT_ptrs_202, %vT_210 mask %vT_ptrs_207 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc884) + %vT_212 = ttg.async_commit_group tokens %vT_211 loc(#loc884) + scf.yield %dq_184, %kT_ptrs_201, %vT_ptrs_202, %vT_ptrs_205, %vT_ptrs_170, %kT_161, %kT_209, %vT_163, %vT_212 : tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc887) + } loc(#loc887) + %vT_ptrs_151 = ttng.warp_group_dot_wait %vT_ptrs_150#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma1> loc(#loc887) + %vT_ptrs_152 = ttg.async_wait {num = 0 : i32} loc(#loc887) + ttg.local_dealloc %vT_131 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc887) + ttg.local_dealloc %kT_130 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc887) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc458) + %dq_ptrs_153 = tt.addptr %dq_ptrs, %ptr_48 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc458) + %dq_ptrs_154 = tt.broadcast %dq_ptrs_153 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc459) + %dq_ptrs_155 = tt.addptr %dq_ptrs_154, %ptr_54 : 
tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc459) + %dq = arith.mulf %vT_ptrs_151, %cst_3 : tensor<128x128xf32, #mma1> loc(#loc460) + %1 = arith.truncf %dq : tensor<128x128xf32, #mma1> to tensor<128x128xbf16, #mma1> loc(#loc159) + %2 = ttg.convert_layout %1 : tensor<128x128xbf16, #mma1> -> tensor<128x128xbf16, #blocked> loc(#loc159) + tt.store %dq_ptrs_155, %2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc159) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc461) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc462) + %offs_n1_31 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc462) + %offs_n1_32 = arith.addi %offs_n1, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc462) + %offs_n1_33 = arith.addi %offs_n1_31, %offs_k_30 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> loc(#loc462) + %ptr = tt.expand_dims %offs_n1_32 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc623) + %ptr_34 = tt.expand_dims %offs_n1_33 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> loc(#loc623) + %ptr_35 = arith.muli %ptr, %cst_1 : tensor<128x1xi32, #blocked> loc(#loc624) + %ptr_36 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc625) + %ptr_37 = tt.addptr %ptr_36, %ptr_35 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc625) + %ptr_38 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc626) + %ptr_39 = tt.expand_dims %ptr_38 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc626) + %ptr_40 = tt.broadcast %ptr_37 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc627) + %ptr_41 = tt.broadcast %ptr_39 : 
tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc627) + %ptr_42 = tt.addptr %ptr_40, %ptr_41 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc627) + %k = tt.load %ptr_42 : tensor<128x128x!tt.ptr, #blocked> loc(#loc628) + %k_43 = ttg.local_alloc %k : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc628) + %ptr_44 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc629) + %ptr_45 = tt.addptr %ptr_44, %ptr_35 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc629) + %ptr_46 = tt.broadcast %ptr_45 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc630) + %ptr_47 = tt.addptr %ptr_46, %ptr_41 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc630) + %v = tt.load %ptr_47 : tensor<128x128x!tt.ptr, #blocked> loc(#loc631) + %v_48 = ttg.local_alloc %v : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc631) + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc465) + %q_adj1 = arith.muli %off_zq, %c8388608_i32 : i32 loc(#loc466) + %off_chz1 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc467) + %sparse_q_num_blks_offset = arith.muli %off_zkv, %c16_i32 : i32 loc(#loc468) + %sparse_q_num_blks_offset_49 = arith.addi %sparse_q_num_blks_offset, %pid : i32 loc(#loc469) + %sparse_q_idx_offset = arith.muli %off_zkv, %c256_i32 : i32 loc(#loc470) + %sparse_q_idx_offset_50 = arith.muli %pid, %c16_i32 : i32 loc(#loc471) + %sparse_q_idx_offset_51 = arith.addi %sparse_q_idx_offset, %sparse_q_idx_offset_50 : i32 loc(#loc472) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset_51 : !tt.ptr, i32 loc(#loc473) + %q_start = tt.load %q_indices, %true : !tt.ptr loc(#loc474) + %q_start_52 = arith.muli %q_start, %c128_i32 : i32 loc(#loc475) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %sparse_q_num_blks_offset_49 : !tt.ptr, i32 loc(#loc476) + %sparse_q_num_blocks_53 = tt.load 
%sparse_q_num_blocks, %true : !tt.ptr loc(#loc477) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc478) + %offs_m1_54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc478) + %offs_m1_55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc478) + %offs_m1_56 = tt.splat %q_start_52 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc479) + %offs_m1_57 = tt.splat %q_start_52 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc479) + %offs_m1_58 = tt.splat %q_start_52 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc479) + %offs_m1_59 = arith.addi %offs_m1_56, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc479) + %offs_m1_60 = arith.addi %offs_m1_57, %offs_m1_54 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc479) + %offs_m1_61 = arith.addi %offs_m1_58, %offs_m1_55 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc479) + %qT_ptrs = tt.expand_dims %offs_m1_59 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc632) + %qT_ptrs_62 = arith.muli %qT_ptrs, %cst_15 : tensor<1x64xi32, #blocked1> loc(#loc633) + %qT_ptrs_63 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc634) + %qT_ptrs_64 = tt.expand_dims %qT_ptrs_63 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc634) + %qT_ptrs_65 = tt.broadcast %qT_ptrs_64 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc635) + %do_ptrs = tt.expand_dims %offs_m1_61 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc636) + 
%do_ptrs_66 = arith.muli %do_ptrs, %cst_14 : tensor<64x1xi32, #blocked> loc(#loc637) + %do_ptrs_67 = tt.broadcast %ptr_39 : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> loc(#loc638) + %hi = arith.muli %sparse_q_num_blocks_53, %c2_i32 : i32 loc(#loc639) + %hi_68 = arith.minsi %hi, %c32_i32 : i32 loc(#loc640) + %tmp44 = tt.broadcast %ptr_34 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma> loc(#loc641) + %tmp45 = arith.extsi %ptr_34 : tensor<128x1xi32, #mma> to tensor<128x1xi64, #mma> loc(#loc642) + %tmp47 = tt.addptr %in_ptr16, %off_zq : !tt.ptr, i32 loc(#loc643) + %do_ptrs_69 = arith.cmpi sgt, %hi_68, %c0_i32 : i32 loc(#loc889) + %tmp47_70 = tt.load %tmp47, %do_ptrs_69 : !tt.ptr loc(#loc645) + %tmp48 = tt.splat %tmp47_70 : i64 -> tensor<128x1xi64, #mma> loc(#loc646) + %tmp48_71 = arith.cmpi slt, %tmp45, %tmp48 : tensor<128x1xi64, #mma> loc(#loc646) + %tmp50 = tt.splat %tmp47_70 : i64 -> tensor<1x64xi64, #mma> loc(#loc647) + %tmp51 = tt.broadcast %tmp48_71 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc648) + %tmp55 = arith.cmpi sge, %ptr_34, %cst_0 : tensor<128x1xi32, #mma> loc(#loc649) + %tmp56 = arith.remsi %ptr_34, %cst_0 : tensor<128x1xi32, #mma> loc(#loc650) + %tmp58 = arith.cmpi ne, %tmp56, %cst : tensor<128x1xi32, #mma> loc(#loc651) + %tmp59 = arith.cmpi slt, %tmp56, %cst : tensor<128x1xi32, #mma> loc(#loc652) + %tmp62 = arith.andi %tmp58, %tmp59 : tensor<128x1xi1, #mma> loc(#loc653) + %tmp63 = arith.addi %tmp56, %cst_0 : tensor<128x1xi32, #mma> loc(#loc654) + %tmp64 = arith.select %tmp62, %tmp63, %tmp56 : tensor<128x1xi1, #mma>, tensor<128x1xi32, #mma> loc(#loc655) + %tmp65 = arith.extsi %tmp64 : tensor<128x1xi32, #mma> to tensor<128x1xi64, #mma> loc(#loc656) + %tmp66 = arith.cmpi slt, %tmp65, %tmp48 : tensor<128x1xi64, #mma> loc(#loc657) + %tmp67 = arith.andi %tmp55, %tmp66 : tensor<128x1xi1, #mma> loc(#loc658) + %tmp77 = tt.broadcast %tmp67 : tensor<128x1xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc659) + %q_indices_72 = 
tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset_51 : !tt.ptr, i32 loc(#loc509) + %q_start_73 = tt.load %q_indices_72, %true : !tt.ptr loc(#loc510) + %q_start_74 = arith.muli %q_start_73, %c128_i32 : i32 loc(#loc511) + %sparse_q_num_blocks_75 = tt.addptr %arg_FULL_Q_NUM_BLKS, %sparse_q_num_blks_offset_49 : !tt.ptr, i32 loc(#loc512) + %sparse_q_num_blocks_76 = tt.load %sparse_q_num_blocks_75, %true : !tt.ptr loc(#loc513) + %offs_m1_77 = tt.splat %q_start_74 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc514) + %offs_m1_78 = tt.splat %q_start_74 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc514) + %offs_m1_79 = tt.splat %q_start_74 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc514) + %offs_m1_80 = arith.addi %offs_m1_77, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc514) + %offs_m1_81 = arith.addi %offs_m1_78, %offs_m1_54 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc514) + %offs_m1_82 = arith.addi %offs_m1_79, %offs_m1_55 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc514) + %qT_ptrs_83 = tt.expand_dims %offs_m1_80 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc660) + %qT_ptrs_84 = arith.muli %qT_ptrs_83, %cst_15 : tensor<1x64xi32, #blocked1> loc(#loc661) + %do_ptrs_85 = tt.expand_dims %offs_m1_82 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc662) + %do_ptrs_86 = arith.muli %do_ptrs_85, %cst_14 : tensor<64x1xi32, #blocked> loc(#loc663) + %hi_87 = arith.muli %sparse_q_num_blocks_76, %c2_i32 : i32 loc(#loc664) + %hi_88 = arith.minsi %hi_87, %c32_i32 : i32 loc(#loc665) + ttng.fence_async_shared {bCluster = false} loc(#loc666) + %dk:2 = scf.for %dk_98 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg19 = %cst_4, %arg20 = %cst_4) -> (tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1>) : 
i32 { + %off_hq1_99 = arith.addi %off_hq1, %dk_98 : i32 loc(#loc517) + %q_adj1_100 = arith.muli %off_hq1_99, %c128_i32 : i32 loc(#loc518) + %q_adj1_101 = arith.addi %q_adj1_100, %q_adj1 : i32 loc(#loc519) + %q_adj1_102 = arith.extsi %q_adj1_101 : i32 to i64 loc(#loc520) + %do_adj1 = arith.muli %off_hq1_99, %c262144_i32 : i32 loc(#loc521) + %do_adj1_103 = arith.addi %do_adj1, %q_adj1 : i32 loc(#loc522) + %do_adj1_104 = arith.extsi %do_adj1_103 : i32 to i64 loc(#loc523) + %off_chz1_105 = arith.addi %off_chz1, %off_hq1_99 : i32 loc(#loc524) + %off_chz1_106 = arith.muli %off_chz1_105, %c2048_i32 : i32 loc(#loc525) + %off_chz1_107 = arith.extsi %off_chz1_106 : i32 to i64 loc(#loc526) + %Q1 = tt.addptr %arg_Q, %q_adj1_102 : !tt.ptr, i64 loc(#loc527) + %DO1 = tt.addptr %arg_DO, %do_adj1_104 : !tt.ptr, i64 loc(#loc528) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_107 : !tt.ptr, i64 loc(#loc529) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_107 : !tt.ptr, i64 loc(#loc530) + %qT_ptrs_108 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc668) + %qT_ptrs_109 = tt.addptr %qT_ptrs_108, %qT_ptrs_62 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc668) + %qT_ptrs_110 = tt.broadcast %qT_ptrs_109 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc635) + %qT_ptrs_111 = tt.addptr %qT_ptrs_110, %qT_ptrs_65 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc635) + %do_ptrs_112 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> loc(#loc669) + %do_ptrs_113 = tt.addptr %do_ptrs_112, %do_ptrs_66 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc669) + %do_ptrs_114 = tt.broadcast %do_ptrs_113 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc638) + %do_ptrs_115 = tt.addptr %do_ptrs_114, %do_ptrs_67 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc638) + %lse = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, 
parent = #mma}>> loc(#loc670) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc671) + %qT = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc842) + %lse_116 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc673) + %do = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc843) + %Di_117 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc675) + %qT_118 = ttg.memdesc_index %qT[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %do_ptrs_119 = tt.splat %do_ptrs_69 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc889) + %qT_120 = ttg.async_copy_global_to_local %qT_ptrs_111, %qT_118 mask %do_ptrs_119 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %qT_121 = ttg.async_commit_group tokens %qT_120 loc(#loc842) + %lse_122 = tt.addptr %lse, %offs_m1_60 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc670) + %lse_123 = ttg.memdesc_index %lse_116[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %do_ptrs_124 = tt.splat %do_ptrs_69 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc889) + %lse_125 = ttg.async_copy_global_to_local %lse_122, %lse_123 mask %do_ptrs_124 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %lse_126 = ttg.async_commit_group tokens %lse_125 loc(#loc673) + %do_127 = ttg.memdesc_index %do[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_ptrs_128 = tt.splat %do_ptrs_69 : i1 -> tensor<64x128xi1, 
#blocked> loc(#loc889) + %do_129 = ttg.async_copy_global_to_local %do_ptrs_115, %do_127 mask %do_ptrs_128 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_130 = ttg.async_commit_group tokens %do_129 loc(#loc843) + %Di_131 = tt.addptr %Di, %offs_m1_60 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc671) + %Di_132 = ttg.memdesc_index %Di_117[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_133 = ttg.async_copy_global_to_local %Di_131, %Di_132 mask %do_ptrs_124 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_134 = ttg.async_commit_group tokens %Di_133 loc(#loc675) + %do_ptrs_135 = arith.cmpi sgt, %hi_68, %c1_i32 : i32 loc(#loc889) + %qT_ptrs_136 = tt.addptr %qT_ptrs_111, %cst_10 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc676) + %do_ptrs_137 = tt.addptr %do_ptrs_115, %cst_11 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc677) + %offs_m1_138 = arith.addi %offs_m1_60, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc678) + %qT_139 = ttg.memdesc_index %qT[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %do_ptrs_140 = tt.splat %do_ptrs_135 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc889) + %qT_141 = ttg.async_copy_global_to_local %qT_ptrs_136, %qT_139 mask %do_ptrs_140 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %qT_142 = ttg.async_commit_group tokens %qT_141 loc(#loc842) + %lse_143 = tt.addptr %lse, %offs_m1_138 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc670) + %lse_144 = 
ttg.memdesc_index %lse_116[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %do_ptrs_145 = tt.splat %do_ptrs_135 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc889) + %lse_146 = ttg.async_copy_global_to_local %lse_143, %lse_144 mask %do_ptrs_145 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %lse_147 = ttg.async_commit_group tokens %lse_146 loc(#loc673) + %do_148 = ttg.memdesc_index %do[%c1_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_ptrs_149 = tt.splat %do_ptrs_135 : i1 -> tensor<64x128xi1, #blocked> loc(#loc889) + %do_150 = ttg.async_copy_global_to_local %do_ptrs_137, %do_148 mask %do_ptrs_149 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_151 = ttg.async_commit_group tokens %do_150 loc(#loc843) + %Di_152 = tt.addptr %Di, %offs_m1_138 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc671) + %Di_153 = ttg.memdesc_index %Di_117[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_154 = ttg.async_copy_global_to_local %Di_152, %Di_153 mask %do_ptrs_145 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_155 = ttg.async_commit_group tokens %Di_154 loc(#loc675) + %do_ptrs_156:19 = scf.for %do_ptrs_211 = %c0_i32 to %hi_68 step %c1_i32 iter_args(%arg22 = %arg20, %arg23 = %arg19, %qT_ptrs_212 = %qT_ptrs_136, %do_ptrs_213 = %do_ptrs_137, %offs_m1_214 = %offs_m1_138, %arg27 = %c1_i32, %arg28 = %c-1_i32, %arg29 = %c1_i32, %arg30 = %c-1_i32, %qT_215 = %qT_121, %qT_216 = %qT_142, %lse_217 = %lse_126, %lse_218 = %lse_147, %do_219 = 
%do_130, %do_220 = %do_151, %Di_221 = %Di_134, %Di_222 = %Di_155, %arg39 = %c64_i32, %offs_m1_223 = %offs_m1_60) -> (tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>) : i32 { + %do_ptrs_224 = arith.subi %hi_68, %c2_i32 : i32 loc(#loc889) + %do_ptrs_225 = arith.cmpi slt, %do_ptrs_211, %do_ptrs_224 : i32 loc(#loc889) + %do_ptrs_226 = arith.subi %hi_68, %c1_i32 : i32 loc(#loc889) + %do_ptrs_227 = arith.cmpi slt, %do_ptrs_211, %do_ptrs_226 : i32 loc(#loc889) + %do_ptrs_228 = arith.addi %arg30, %c1_i32 : i32 loc(#loc889) + %do_ptrs_229 = arith.cmpi sge, %do_ptrs_228, %c2_i32 : i32 loc(#loc889) + %do_ptrs_230 = arith.select %do_ptrs_229, %c0_i32, %do_ptrs_228 : i32 loc(#loc889) + %do_ptrs_231 = arith.addi %arg28, %c1_i32 : i32 loc(#loc889) + %do_ptrs_232 = arith.cmpi sge, %do_ptrs_231, %c3_i32 : i32 loc(#loc889) + %do_ptrs_233 = arith.select %do_ptrs_232, %c0_i32, %do_ptrs_231 : i32 loc(#loc889) + %qT_234 = ttg.async_wait %qT_215, %lse_217, %do_219, %Di_221 {num = 4 : i32} loc(#loc842) + %qT_235 = ttg.memdesc_index %qT[%do_ptrs_233] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %dk_236 = ttg.memdesc_trans %qT_235 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc679) + %lse_237 = ttg.memdesc_index %lse_116[%do_ptrs_230] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %lse_238 = ttg.local_load %lse_237 token %qT_234 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 
2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc673) + %lse_239 = arith.cmpf oeq, %lse_238, %cst_19 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc680) + %lse_240 = arith.select %lse_239, %cst_20, %lse_238 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc681) + %qkT = ttng.warp_group_dot %k_43, %qT_235, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc666) + %qkT_241:3 = ttng.warp_group_dot_wait %qkT, %k_43, %qT_235 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc666) + %qkT_242 = arith.mulf %qkT_241#0, %cst_6 : tensor<128x64xf32, #mma> loc(#loc682) + %m = tt.expand_dims %offs_m1_223 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> loc(#loc683) + %tmp44_243 = tt.broadcast %m : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> loc(#loc641) + %tmp44_244 = arith.cmpi sge, %tmp44_243, %tmp44 : tensor<128x64xi32, #mma> loc(#loc641) + %tmp49 = arith.extsi %m : tensor<1x64xi32, #mma> to tensor<1x64xi64, #mma> loc(#loc684) + %tmp50_245 = arith.cmpi slt, %tmp49, %tmp50 : tensor<1x64xi64, #mma> loc(#loc647) + %tmp51_246 = tt.broadcast %tmp50_245 : tensor<1x64xi1, #mma> -> tensor<128x64xi1, #mma> loc(#loc648) + %tmp51_247 = arith.andi %tmp51, %tmp51_246 : tensor<128x64xi1, #mma> loc(#loc648) + %tmp52 = arith.andi %tmp44_244, %tmp51_247 : tensor<128x64xi1, #mma> loc(#loc685) + %tmp68 = arith.subi %tmp44, %tmp44_243 : tensor<128x64xi32, #mma> loc(#loc686) + %tmp69 = arith.remsi %tmp68, %cst_18 : tensor<128x64xi32, #mma> loc(#loc687) + %tmp70 = arith.cmpi ne, %tmp69, %cst_21 : tensor<128x64xi32, #mma> loc(#loc688) + %tmp71 = arith.cmpi slt, %tmp69, %cst_21 : 
tensor<128x64xi32, #mma> loc(#loc689) + %tmp73 = arith.andi %tmp70, %tmp71 : tensor<128x64xi1, #mma> loc(#loc690) + %tmp74 = arith.addi %tmp69, %cst_18 : tensor<128x64xi32, #mma> loc(#loc691) + %tmp75 = arith.select %tmp73, %tmp74, %tmp69 : tensor<128x64xi1, #mma>, tensor<128x64xi32, #mma> loc(#loc692) + %tmp76 = arith.cmpi eq, %tmp75, %cst_21 : tensor<128x64xi32, #mma> loc(#loc693) + %tmp77_248 = arith.andi %tmp77, %tmp76 : tensor<128x64xi1, #mma> loc(#loc659) + %tmp78 = arith.ori %tmp52, %tmp77_248 : tensor<128x64xi1, #mma> loc(#loc694) + %post_mod_scores = arith.select %tmp78, %qkT_242, %cst_7 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> loc(#loc695) + %post_mod_scores_249 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32, #mma> loc(#loc696) + %pT = tt.expand_dims %lse_240 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc697) + %pT_250 = tt.broadcast %pT : tensor<1x64xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc698) + %pT_251 = arith.subf %post_mod_scores_249, %pT_250 : tensor<128x64xf32, #mma> loc(#loc698) + %pT_252 = math.exp2 %pT_251 : tensor<128x64xf32, #mma> loc(#loc699) + %do_253 = ttg.memdesc_index %do[%do_ptrs_233] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %dpT = ttg.memdesc_trans %do_253 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc700) + %dv = arith.truncf %pT_252 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc701) + %dv_254 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc701) + %dv_255 = ttng.warp_group_dot %dv_254, %do_253, %arg23 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, 
#shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc702) + %Di_256 = ttg.memdesc_index %Di_117[%do_ptrs_230] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_257 = ttg.local_load %Di_256 token %qT_234 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc675) + %dpT_258 = ttng.warp_group_dot %v_48, %dpT, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc703) + %dpT_259:3 = ttng.warp_group_dot_wait %dpT_258, %v_48, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc703) + %dsT = tt.expand_dims %Di_257 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc704) + %dsT_260 = tt.broadcast %dsT : tensor<1x64xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc705) + %dsT_261 = arith.subf %dpT_259#0, %dsT_260 : tensor<128x64xf32, #mma> loc(#loc705) + %dsT_262 = arith.mulf %pT_252, %dsT_261 : tensor<128x64xf32, #mma> loc(#loc706) + %dsT_263 = arith.select %tmp78, %dsT_262, %cst_5 : tensor<128x64xi1, #mma>, tensor<128x64xf32, #mma> loc(#loc707) + %dk_264 = arith.truncf %dsT_263 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc708) + %dk_265 = ttg.convert_layout %dk_264 : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc708) + %dk_266 = ttng.warp_group_dot %dk_265, %dk_236, %arg22 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc709) + %offs_m1_267 = tt.splat %arg39 : 
i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc678) + %offs_m1_268 = arith.addi %offs_m1_223, %offs_m1_267 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc678) + %do_ptrs_269 = arith.addi %do_ptrs_211, %c1_i32 : i32 loc(#loc889) + %cur_block_idx = arith.divsi %do_ptrs_269, %c2_i32 : i32 loc(#loc844) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc845) + %cur_block_270 = tt.load %cur_block, %do_ptrs_227 evictionPolicy = evict_last : !tt.ptr loc(#loc846) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc847) + %next_block_271 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_53 : i32 loc(#loc848) + %next_block_272 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc849) + %do_ptrs_273 = arith.andi %do_ptrs_227, %next_block_271 : i1 loc(#loc889) + %next_block_274 = tt.load %next_block_272, %do_ptrs_273 evictionPolicy = evict_last : !tt.ptr loc(#loc850) + %needs_jump = arith.addi %do_ptrs_211, %c2_i32 : i32 loc(#loc851) + %needs_jump_275 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc852) + %needs_jump_276 = arith.cmpi eq, %needs_jump_275, %c0_i32 : i32 loc(#loc853) + %jump_to_block = arith.subi %next_block_274, %cur_block_270 : i32 loc(#loc854) + %jump_to_block_277 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc855) + %jump_to_block_278 = arith.subi %jump_to_block_277, %c64_i32 : i32 loc(#loc856) + %offset = arith.extui %needs_jump_276 : i1 to i32 loc(#loc857) + %offset_279 = arith.muli %jump_to_block_278, %offset : i32 loc(#loc857) + %offset_280 = arith.subi %c1_i32, %offset : i32 loc(#loc858) + %offset_281 = arith.muli %offset_280, %c64_i32 : i32 loc(#loc859) + %offset_282 = arith.addi %offset_279, %offset_281 : i32 loc(#loc860) + %qT_ptrs_283 = arith.muli %offset_282, %c4096_i32 : i32 loc(#loc711) + %qT_ptrs_284 = tt.splat %qT_ptrs_283 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc676) + %qT_ptrs_285 = tt.addptr %qT_ptrs_212, %qT_ptrs_284 : tensor<128x64x!tt.ptr, #blocked1>, 
tensor<128x64xi32, #blocked1> loc(#loc676) + %do_ptrs_286 = arith.muli %offset_282, %c128_i32 : i32 loc(#loc712) + %do_ptrs_287 = tt.splat %do_ptrs_286 : i32 -> tensor<64x128xi32, #blocked> loc(#loc677) + %do_ptrs_288 = tt.addptr %do_ptrs_213, %do_ptrs_287 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc677) + %offs_m1_289 = tt.splat %offset_282 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc678) + %offs_m1_290 = arith.addi %offs_m1_214, %offs_m1_289 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc678) + %do_ptrs_291 = arith.addi %arg29, %c1_i32 : i32 loc(#loc889) + %do_ptrs_292 = arith.cmpi sge, %do_ptrs_291, %c2_i32 : i32 loc(#loc889) + %do_ptrs_293 = arith.select %do_ptrs_292, %c0_i32, %do_ptrs_291 : i32 loc(#loc889) + %do_ptrs_294 = arith.addi %arg27, %c1_i32 : i32 loc(#loc889) + %do_ptrs_295 = arith.cmpi sge, %do_ptrs_294, %c3_i32 : i32 loc(#loc889) + %do_ptrs_296 = arith.select %do_ptrs_295, %c0_i32, %do_ptrs_294 : i32 loc(#loc889) + %qT_297 = ttg.memdesc_index %qT[%do_ptrs_296] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %do_ptrs_298 = tt.splat %do_ptrs_225 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc889) + %qT_299 = ttg.async_copy_global_to_local %qT_ptrs_285, %qT_297 mask %do_ptrs_298 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc842) + %qT_300 = ttg.async_commit_group tokens %qT_299 loc(#loc842) + %lse_301 = tt.addptr %lse, %offs_m1_290 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc670) + %lse_302 = ttg.memdesc_index %lse_116[%do_ptrs_293] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %do_ptrs_303 = tt.splat %do_ptrs_225 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc889) + %lse_304 
= ttg.async_copy_global_to_local %lse_301, %lse_302 mask %do_ptrs_303 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc673) + %lse_305 = ttg.async_commit_group tokens %lse_304 loc(#loc673) + %do_306 = ttg.memdesc_index %do[%do_ptrs_296] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_ptrs_307 = tt.splat %do_ptrs_225 : i1 -> tensor<64x128xi1, #blocked> loc(#loc889) + %do_308 = ttg.async_copy_global_to_local %do_ptrs_288, %do_306 mask %do_ptrs_307 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc843) + %do_309 = ttg.async_commit_group tokens %do_308 loc(#loc843) + %Di_310 = tt.addptr %Di, %offs_m1_290 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc671) + %Di_311 = ttg.memdesc_index %Di_117[%do_ptrs_293] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_312 = ttg.async_copy_global_to_local %Di_310, %Di_311 mask %do_ptrs_303 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc675) + %Di_313 = ttg.async_commit_group tokens %Di_312 loc(#loc675) + scf.yield %dk_266, %dv_255, %qT_ptrs_285, %do_ptrs_288, %offs_m1_290, %do_ptrs_296, %do_ptrs_233, %do_ptrs_293, %do_ptrs_230, %qT_216, %qT_300, %lse_218, %lse_305, %do_220, %do_309, %Di_222, %Di_313, %offset_282, %offs_m1_268 : tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> 
loc(#loc889) + } loc(#loc889) + %do_ptrs_157:2 = ttng.warp_group_dot_wait %do_ptrs_156#1, %do_ptrs_156#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1> loc(#loc889) + %do_ptrs_158 = ttg.async_wait {num = 0 : i32} loc(#loc889) + ttg.local_dealloc %Di_117 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc889) + ttg.local_dealloc %do : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc889) + ttg.local_dealloc %lse_116 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc889) + ttg.local_dealloc %qT : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc889) + %qT_ptrs_159 = tt.addptr %qT_ptrs_108, %qT_ptrs_84 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc713) + %qT_ptrs_160 = tt.broadcast %qT_ptrs_159 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc714) + %qT_ptrs_161 = tt.addptr %qT_ptrs_160, %qT_ptrs_65 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc714) + %do_ptrs_162 = tt.addptr %do_ptrs_112, %do_ptrs_86 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc715) + %do_ptrs_163 = tt.broadcast %do_ptrs_162 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc716) + %do_ptrs_164 = tt.addptr %do_ptrs_163, %do_ptrs_67 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc716) + %qT_165 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc861) + %lse_166 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc718) + %do_167 = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc862) + %Di_168 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc720) + %do_ptrs_169 = arith.cmpi sgt, %hi_88, %c0_i32 : i32 loc(#loc890) + %qT_170 = ttg.memdesc_index %qT_165[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> 
!ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %do_ptrs_171 = tt.splat %do_ptrs_169 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc890) + %qT_172 = ttg.async_copy_global_to_local %qT_ptrs_161, %qT_170 mask %do_ptrs_171 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %qT_173 = ttg.async_commit_group tokens %qT_172 loc(#loc861) + %lse_174 = tt.addptr %lse, %offs_m1_81 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc721) + %lse_175 = ttg.memdesc_index %lse_166[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %do_ptrs_176 = tt.splat %do_ptrs_169 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc890) + %lse_177 = ttg.async_copy_global_to_local %lse_174, %lse_175 mask %do_ptrs_176 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %lse_178 = ttg.async_commit_group tokens %lse_177 loc(#loc718) + %do_179 = ttg.memdesc_index %do_167[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_ptrs_180 = tt.splat %do_ptrs_169 : i1 -> tensor<64x128xi1, #blocked> loc(#loc890) + %do_181 = ttg.async_copy_global_to_local %do_ptrs_164, %do_179 mask %do_ptrs_180 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_182 = ttg.async_commit_group tokens %do_181 loc(#loc862) + %Di_183 = tt.addptr %Di, %offs_m1_81 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc722) + %Di_184 = ttg.memdesc_index %Di_168[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_185 = 
ttg.async_copy_global_to_local %Di_183, %Di_184 mask %do_ptrs_176 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_186 = ttg.async_commit_group tokens %Di_185 loc(#loc720) + %do_ptrs_187 = arith.cmpi sgt, %hi_88, %c1_i32 : i32 loc(#loc890) + %qT_ptrs_188 = tt.addptr %qT_ptrs_161, %cst_10 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc723) + %do_ptrs_189 = tt.addptr %do_ptrs_164, %cst_11 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc724) + %offs_m1_190 = arith.addi %offs_m1_81, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc725) + %qT_191 = ttg.memdesc_index %qT_165[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %do_ptrs_192 = tt.splat %do_ptrs_187 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc890) + %qT_193 = ttg.async_copy_global_to_local %qT_ptrs_188, %qT_191 mask %do_ptrs_192 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %qT_194 = ttg.async_commit_group tokens %qT_193 loc(#loc861) + %lse_195 = tt.addptr %lse, %offs_m1_190 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc721) + %lse_196 = ttg.memdesc_index %lse_166[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %do_ptrs_197 = tt.splat %do_ptrs_187 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc890) + %lse_198 = ttg.async_copy_global_to_local %lse_195, %lse_196 mask %do_ptrs_197 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %lse_199 = ttg.async_commit_group tokens %lse_198 loc(#loc718) + %do_200 = ttg.memdesc_index %do_167[%c1_i32] : !ttg.memdesc<3x64x128xbf16, 
#shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_ptrs_201 = tt.splat %do_ptrs_187 : i1 -> tensor<64x128xi1, #blocked> loc(#loc890) + %do_202 = ttg.async_copy_global_to_local %do_ptrs_189, %do_200 mask %do_ptrs_201 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_203 = ttg.async_commit_group tokens %do_202 loc(#loc862) + %Di_204 = tt.addptr %Di, %offs_m1_190 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc722) + %Di_205 = ttg.memdesc_index %Di_168[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_206 = ttg.async_copy_global_to_local %Di_204, %Di_205 mask %do_ptrs_197 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_207 = ttg.async_commit_group tokens %Di_206 loc(#loc720) + %do_ptrs_208:17 = scf.for %do_ptrs_211 = %c0_i32 to %hi_88 step %c1_i32 iter_args(%do_ptrs_212 = %do_ptrs_157#1, %do_ptrs_213 = %do_ptrs_157#0, %qT_ptrs_214 = %qT_ptrs_188, %do_ptrs_215 = %do_ptrs_189, %offs_m1_216 = %offs_m1_190, %arg27 = %c1_i32, %arg28 = %c-1_i32, %arg29 = %c1_i32, %arg30 = %c-1_i32, %qT_217 = %qT_173, %qT_218 = %qT_194, %lse_219 = %lse_178, %lse_220 = %lse_199, %do_221 = %do_182, %do_222 = %do_203, %Di_223 = %Di_186, %Di_224 = %Di_207) -> (tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { + %do_ptrs_225 = arith.subi %hi_88, %c2_i32 : i32 loc(#loc890) + %do_ptrs_226 = arith.cmpi slt, %do_ptrs_211, %do_ptrs_225 : i32 loc(#loc890) + 
%do_ptrs_227 = arith.subi %hi_88, %c1_i32 : i32 loc(#loc890) + %do_ptrs_228 = arith.cmpi slt, %do_ptrs_211, %do_ptrs_227 : i32 loc(#loc890) + %do_ptrs_229 = arith.addi %arg30, %c1_i32 : i32 loc(#loc890) + %do_ptrs_230 = arith.cmpi sge, %do_ptrs_229, %c2_i32 : i32 loc(#loc890) + %do_ptrs_231 = arith.select %do_ptrs_230, %c0_i32, %do_ptrs_229 : i32 loc(#loc890) + %do_ptrs_232 = arith.addi %arg28, %c1_i32 : i32 loc(#loc890) + %do_ptrs_233 = arith.cmpi sge, %do_ptrs_232, %c3_i32 : i32 loc(#loc890) + %do_ptrs_234 = arith.select %do_ptrs_233, %c0_i32, %do_ptrs_232 : i32 loc(#loc890) + %qT_235 = ttg.async_wait %qT_217, %lse_219, %do_221, %Di_223 {num = 4 : i32} loc(#loc861) + %qT_236 = ttg.memdesc_index %qT_165[%do_ptrs_234] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %dk_237 = ttg.memdesc_trans %qT_236 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc726) + %lse_238 = ttg.memdesc_index %lse_166[%do_ptrs_231] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %lse_239 = ttg.local_load %lse_238 token %qT_235 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc718) + %lse_240 = arith.cmpf oeq, %lse_239, %cst_19 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc727) + %lse_241 = arith.select %lse_240, %cst_20, %lse_239 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc728) + %qkT = ttng.warp_group_dot %k_43, %qT_236, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc729) + %qkT_242:3 = ttng.warp_group_dot_wait %qkT, %k_43, 
%qT_236 {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc729) + %qkT_243 = arith.mulf %qkT_242#0, %cst_6 : tensor<128x64xf32, #mma> loc(#loc730) + %post_mod_scores = arith.mulf %qkT_243, %cst_8 : tensor<128x64xf32, #mma> loc(#loc731) + %pT = tt.expand_dims %lse_241 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc732) + %pT_244 = tt.broadcast %pT : tensor<1x64xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc733) + %pT_245 = arith.subf %post_mod_scores, %pT_244 : tensor<128x64xf32, #mma> loc(#loc733) + %pT_246 = math.exp2 %pT_245 : tensor<128x64xf32, #mma> loc(#loc734) + %do_247 = ttg.memdesc_index %do_167[%do_ptrs_234] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %dpT = ttg.memdesc_trans %do_247 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc735) + %dv = arith.truncf %pT_246 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc736) + %dv_248 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc736) + %dv_249 = ttng.warp_group_dot %dv_248, %do_247, %do_ptrs_213 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc737) + %Di_250 = ttg.memdesc_index %Di_168[%do_ptrs_231] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_251 = ttg.local_load %Di_250 token %qT_235 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> 
loc(#loc720) + %dpT_252 = ttng.warp_group_dot %v_48, %dpT, %cst_5 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma> loc(#loc738) + %dpT_253:3 = ttng.warp_group_dot_wait %dpT_252, %v_48, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc738) + %dsT = tt.expand_dims %Di_251 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> loc(#loc739) + %dsT_254 = tt.broadcast %dsT : tensor<1x64xf32, #mma> -> tensor<128x64xf32, #mma> loc(#loc740) + %dsT_255 = arith.subf %dpT_253#0, %dsT_254 : tensor<128x64xf32, #mma> loc(#loc740) + %dsT_256 = arith.mulf %pT_246, %dsT_255 : tensor<128x64xf32, #mma> loc(#loc741) + %dk_257 = arith.truncf %dsT_256 : tensor<128x64xf32, #mma> to tensor<128x64xbf16, #mma> loc(#loc742) + %dk_258 = ttg.convert_layout %dk_257 : tensor<128x64xbf16, #mma> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> loc(#loc742) + %dk_259 = ttng.warp_group_dot %dk_258, %dk_237, %do_ptrs_212 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma1> loc(#loc743) + %do_ptrs_260 = arith.addi %do_ptrs_211, %c1_i32 : i32 loc(#loc890) + %cur_block_idx = arith.divsi %do_ptrs_260, %c2_i32 : i32 loc(#loc863) + %cur_block = tt.addptr %q_indices_72, %cur_block_idx : !tt.ptr, i32 loc(#loc864) + %cur_block_261 = tt.load %cur_block, %do_ptrs_228 evictionPolicy = evict_last : !tt.ptr loc(#loc865) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc866) + %next_block_262 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_76 : i32 loc(#loc867) + %next_block_263 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 
loc(#loc868) + %do_ptrs_264 = arith.andi %do_ptrs_228, %next_block_262 : i1 loc(#loc890) + %next_block_265 = tt.load %next_block_263, %do_ptrs_264 evictionPolicy = evict_last : !tt.ptr loc(#loc869) + %needs_jump = arith.addi %do_ptrs_211, %c2_i32 : i32 loc(#loc870) + %needs_jump_266 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc871) + %needs_jump_267 = arith.cmpi eq, %needs_jump_266, %c0_i32 : i32 loc(#loc872) + %jump_to_block = arith.subi %next_block_265, %cur_block_261 : i32 loc(#loc873) + %jump_to_block_268 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc874) + %jump_to_block_269 = arith.subi %jump_to_block_268, %c64_i32 : i32 loc(#loc875) + %offset = arith.extui %needs_jump_267 : i1 to i32 loc(#loc876) + %offset_270 = arith.muli %jump_to_block_269, %offset : i32 loc(#loc876) + %offset_271 = arith.subi %c1_i32, %offset : i32 loc(#loc877) + %offset_272 = arith.muli %offset_271, %c64_i32 : i32 loc(#loc878) + %offset_273 = arith.addi %offset_270, %offset_272 : i32 loc(#loc879) + %qT_ptrs_274 = arith.muli %offset_273, %c4096_i32 : i32 loc(#loc745) + %qT_ptrs_275 = tt.splat %qT_ptrs_274 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc723) + %qT_ptrs_276 = tt.addptr %qT_ptrs_214, %qT_ptrs_275 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc723) + %do_ptrs_277 = arith.muli %offset_273, %c128_i32 : i32 loc(#loc746) + %do_ptrs_278 = tt.splat %do_ptrs_277 : i32 -> tensor<64x128xi32, #blocked> loc(#loc724) + %do_ptrs_279 = tt.addptr %do_ptrs_215, %do_ptrs_278 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc724) + %offs_m1_280 = tt.splat %offset_273 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc725) + %offs_m1_281 = arith.addi %offs_m1_216, %offs_m1_280 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc725) + %do_ptrs_282 = arith.addi %arg29, %c1_i32 : i32 loc(#loc890) + %do_ptrs_283 = arith.cmpi sge, %do_ptrs_282, %c2_i32 : i32 loc(#loc890) + %do_ptrs_284 = arith.select 
%do_ptrs_283, %c0_i32, %do_ptrs_282 : i32 loc(#loc890) + %do_ptrs_285 = arith.addi %arg27, %c1_i32 : i32 loc(#loc890) + %do_ptrs_286 = arith.cmpi sge, %do_ptrs_285, %c3_i32 : i32 loc(#loc890) + %do_ptrs_287 = arith.select %do_ptrs_286, %c0_i32, %do_ptrs_285 : i32 loc(#loc890) + %qT_288 = ttg.memdesc_index %qT_165[%do_ptrs_287] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %do_ptrs_289 = tt.splat %do_ptrs_226 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc890) + %qT_290 = ttg.async_copy_global_to_local %qT_ptrs_276, %qT_288 mask %do_ptrs_289 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc861) + %qT_291 = ttg.async_commit_group tokens %qT_290 loc(#loc861) + %lse_292 = tt.addptr %lse, %offs_m1_281 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc721) + %lse_293 = ttg.memdesc_index %lse_166[%do_ptrs_284] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %do_ptrs_294 = tt.splat %do_ptrs_226 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc890) + %lse_295 = ttg.async_copy_global_to_local %lse_292, %lse_293 mask %do_ptrs_294 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc718) + %lse_296 = ttg.async_commit_group tokens %lse_295 loc(#loc718) + %do_297 = ttg.memdesc_index %do_167[%do_ptrs_287] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_ptrs_298 = tt.splat %do_ptrs_226 : i1 -> tensor<64x128xi1, #blocked> loc(#loc890) + %do_299 = ttg.async_copy_global_to_local %do_ptrs_279, %do_297 mask %do_ptrs_298 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc862) + %do_300 = 
ttg.async_commit_group tokens %do_299 loc(#loc862) + %Di_301 = tt.addptr %Di, %offs_m1_281 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> loc(#loc722) + %Di_302 = ttg.memdesc_index %Di_168[%do_ptrs_284] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_303 = ttg.async_copy_global_to_local %Di_301, %Di_302 mask %do_ptrs_294 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc720) + %Di_304 = ttg.async_commit_group tokens %Di_303 loc(#loc720) + scf.yield %dk_259, %dv_249, %qT_ptrs_276, %do_ptrs_279, %offs_m1_281, %do_ptrs_287, %do_ptrs_234, %do_ptrs_284, %do_ptrs_231, %qT_218, %qT_291, %lse_220, %lse_296, %do_222, %do_300, %Di_224, %Di_304 : tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc890) + } loc(#loc890) + %do_ptrs_209:2 = ttng.warp_group_dot_wait %do_ptrs_208#1, %do_ptrs_208#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1> loc(#loc890) + %do_ptrs_210 = ttg.async_wait {num = 0 : i32} loc(#loc890) + ttg.local_dealloc %Di_168 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc890) + ttg.local_dealloc %do_167 : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc890) + ttg.local_dealloc %lse_166 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc890) + ttg.local_dealloc %qT_165 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc890) + scf.yield %do_ptrs_209#0, %do_ptrs_209#1 : tensor<128x128xf32, #mma1>, tensor<128x128xf32, #mma1> loc(#loc277) + } loc(#loc667) + %dv_ptrs = tt.splat %DV : !tt.ptr -> 
tensor<128x1x!tt.ptr, #blocked> loc(#loc577) + %dv_ptrs_89 = tt.addptr %dv_ptrs, %ptr_35 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc577) + %dv_ptrs_90 = tt.broadcast %dv_ptrs_89 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc578) + %dv_ptrs_91 = tt.addptr %dv_ptrs_90, %ptr_41 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc578) + %1 = arith.truncf %dk#0 : tensor<128x128xf32, #mma1> to tensor<128x128xbf16, #mma1> loc(#loc280) + %2 = ttg.convert_layout %1 : tensor<128x128xbf16, #mma1> -> tensor<128x128xbf16, #blocked> loc(#loc280) + tt.store %dv_ptrs_91, %2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc280) + %dk_92 = arith.mulf %dk#1, %cst_3 : tensor<128x128xf32, #mma1> loc(#loc579) + %mask = arith.cmpi slt, %ptr, %cst_13 : tensor<128x1xi32, #blocked> loc(#loc580) + %xindex = tt.broadcast %ptr_35 : tensor<128x1xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc581) + %xindex_93 = arith.addi %ptr_41, %xindex : tensor<128x128xi32, #blocked> loc(#loc581) + %xindex_94 = tt.splat %k_adj : i32 -> tensor<128x128xi32, #blocked> loc(#loc582) + %xindex_95 = arith.addi %xindex_93, %xindex_94 : tensor<128x128xi32, #blocked> loc(#loc582) + %xindex_96 = tt.splat %dv_adj : i32 -> tensor<128x128xi32, #blocked> loc(#loc583) + %xindex_97 = arith.addi %xindex_95, %xindex_96 : tensor<128x128xi32, #blocked> loc(#loc583) + %3 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> loc(#loc286) + %4 = tt.addptr %3, %xindex_97 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc286) + %5 = tt.broadcast %mask : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc287) + %6 = arith.truncf %dk_92 : tensor<128x128xf32, #mma1> to tensor<128x128xbf16, #mma1> loc(#loc287) + %7 = ttg.convert_layout %6 : tensor<128x128xbf16, #mma1> -> tensor<128x128xbf16, #blocked> loc(#loc287) + tt.store %4, %7, %5 : tensor<128x128x!tt.ptr, #blocked> loc(#loc287) + } 
loc(#loc18) + tt.return loc(#loc288) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":111:24) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":115:27) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":116:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":117:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:25) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:47) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:35) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:59) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:50) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:61) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":131:9) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":132:9) +#loc15 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":133:10) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":136:26) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:14) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:7) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":140:24) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:29) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:54) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:44) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":145:35) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:55) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:78) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:50) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:83) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:68) +#loc29 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:30) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:52) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:40) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:63) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:32) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:42) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:66) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:30) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:35) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:46) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:56) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":163:17) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":164:19) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":167:19) +#loc43 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":168:21) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":169:25) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":174:36) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":175:29) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:27) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":178:107) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:38) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:20) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:56) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:49) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":835:23) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":179:111) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:34) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:25) +#loc57 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:33) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:26) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:30) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:50) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":191:18) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":195:30) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:27) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:41) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:53) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:39) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:42) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:29) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:26) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":207:12) +#loc71 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:37) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:18) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:56) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:49) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:18) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:49) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:43) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:63) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":482:23) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":405:12) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:34) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":397:28) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:23) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":486:22) +#loc85 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":487:23) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":488:23) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":489:23) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:39) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:22) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:19) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":458:105) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":528:104) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:19) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":415:19) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":459:19) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:30) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":461:14) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":464:36) +#loc99 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":483:23) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":490:23) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":493:24) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":494:24) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":496:25) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":497:92) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":500:24) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":501:24) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":502:39) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":503:25) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":504:24) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":505:24) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":506:23) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":507:25) +#loc113 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":508:25) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":509:92) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":511:24) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":512:24) +#loc117 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":513:39) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":514:25) +#loc119 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":515:24) +#loc120 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":516:24) +#loc121 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":521:69) +#loc122 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":524:27) +#loc123 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:21) +#loc124 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":530:20) +#loc125 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:14) +#loc126 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":549:43) +#loc127 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":551:15) +#loc128 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:21) +#loc129 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":417:19) +#loc130 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":788:33) +#loc131 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":411:64) +#loc132 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:38) +#loc133 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:24) +#loc134 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:109) +#loc135 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:113) +#loc136 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:55) +#loc137 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:25) +#loc138 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:30) +#loc139 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:35) +#loc140 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:60) +#loc141 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:34) +#loc142 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:48) +#loc143 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:63) +#loc144 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:29) +#loc145 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:47) +#loc146 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:61) +#loc147 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:42) +#loc148 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:28) +#loc149 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":214:39) +#loc150 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:31) +#loc151 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:45) +#loc152 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:62) +#loc153 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:43) +#loc154 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":218:33) +#loc155 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":226:16) +#loc156 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:24) +#loc157 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:56) +#loc158 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":232:14) +#loc159 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":234:30) +#loc160 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":252:25) +#loc161 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":253:29) +#loc162 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":256:107) +#loc163 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":257:107) +#loc164 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:32) +#loc165 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:56) +#loc166 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:34) +#loc167 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:58) +#loc168 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:80) +#loc169 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:53) +#loc170 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:81) +#loc171 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:70) +#loc172 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":286:32) +#loc173 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:30) +#loc174 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:43) +#loc175 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:55) +#loc176 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:42) +#loc177 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:45) +#loc178 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:32) +#loc179 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:26) +#loc180 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":298:16) +#loc181 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:37) +#loc182 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:56) +#loc183 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:49) +#loc184 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:27) +#loc185 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:38) +#loc186 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:51) +#loc187 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:42) +#loc188 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:61) +#loc189 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":698:25) +#loc190 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":618:12) +#loc191 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":699:25) +#loc192 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:35) +#loc193 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":610:28) +#loc194 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:24) +#loc195 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":702:24) +#loc196 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":704:24) +#loc197 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":705:24) +#loc198 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":709:25) +#loc199 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":710:25) +#loc200 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":712:25) +#loc201 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":713:92) +#loc202 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":716:24) +#loc203 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":717:24) +#loc204 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":718:39) +#loc205 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":719:25) +#loc206 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":720:24) +#loc207 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":721:24) +#loc208 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":731:24) +#loc209 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":306:41) +#loc210 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:34) +#loc211 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:47) +#loc212 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:64) +#loc213 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:46) +#loc214 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":310:36) +#loc215 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":318:20) +#loc216 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":676:20) +#loc217 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":262:30) +#loc218 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:51) +#loc219 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:34) +#loc220 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:44) +#loc221 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:67) +#loc222 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:36) +#loc223 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:46) +#loc224 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:70) +#loc225 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:39) +#loc226 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:50) +#loc227 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:60) +#loc228 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":271:21) +#loc229 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":272:23) +#loc230 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":275:25) +#loc231 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":276:29) +#loc232 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:18) +#loc233 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:19) +#loc234 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:28) +#loc235 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:29) +#loc236 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":669:105) +#loc237 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:22) +#loc238 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":741:99) +#loc239 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:21) +#loc240 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:19) +#loc241 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:19) +#loc242 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":628:19) +#loc243 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:52) +#loc244 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:26) +#loc245 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:46) +#loc246 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":678:15) +#loc247 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":680:36) +#loc248 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":703:25) +#loc249 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":706:24) +#loc250 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":722:24) +#loc251 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":723:25) +#loc252 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":724:25) +#loc253 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":725:92) +#loc254 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":727:24) +#loc255 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":728:24) +#loc256 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":729:39) +#loc257 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":730:25) +#loc258 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":732:24) +#loc259 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":736:69) +#loc260 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":739:27) +#loc261 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:44) +#loc262 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:40) +#loc263 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:22) +#loc264 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:29) +#loc265 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:24) +#loc266 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:43) +#loc267 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:20) +#loc268 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:25) +#loc269 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:22) +#loc270 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:16) +#loc271 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":773:45) +#loc272 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:24) +#loc273 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:43) +#loc274 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":623:62) +#loc275 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:28) +#loc276 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:28) +#loc277 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":303:12) +#loc278 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:23) +#loc279 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:55) +#loc280 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":330:30) +#loc281 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":334:14) +#loc282 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":337:29) +#loc283 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:27) +#loc284 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:41) +#loc285 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:58) +#loc286 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:29) +#loc287 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:69) +#loc288 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:4) +#loc307 = loc("pid"(#loc2)) +#loc308 = loc("off_zq"(#loc3)) +#loc309 = loc("off_hkv"(#loc4)) +#loc310 = loc("off_zkv"(#loc5)) +#loc311 = loc("k_adj"(#loc6)) +#loc312 = loc("k_adj"(#loc7)) +#loc313 = loc("k_adj"(#loc8)) +#loc314 = loc("k_adj"(#loc9)) +#loc315 = loc("dv_adj"(#loc10)) +#loc316 = loc("dv_adj"(#loc11)) +#loc317 = loc("dv_adj"(#loc12)) +#loc318 = loc("K"(#loc13)) +#loc319 = loc("V"(#loc14)) +#loc320 = loc("DV"(#loc15)) +#loc321 = loc("offs_k"(#loc16)) +#loc322 = loc("off_pid"(#loc19)) +#loc323 = loc("off_hq2"(#loc20)) +#loc324 = loc("off_hq2"(#loc21)) +#loc325 = loc("off_hq2"(#loc22)) +#loc326 = loc("start_m2_block"(#loc23)) +#loc327 = loc("sparse_kv_num_blks_offset"(#loc24)) +#loc328 = loc("sparse_kv_num_blks_offset"(#loc25)) +#loc329 = loc("sparse_kv_idx_offset"(#loc26)) +#loc330 = loc("sparse_kv_idx_offset"(#loc27)) +#loc331 = loc("sparse_kv_idx_offset"(#loc28)) +#loc332 
= loc("q_adj2"(#loc29)) +#loc333 = loc("q_adj2"(#loc30)) +#loc334 = loc("q_adj2"(#loc31)) +#loc335 = loc("q_adj2"(#loc32)) +#loc336 = loc("do_adj2"(#loc33)) +#loc337 = loc("do_adj2"(#loc34)) +#loc338 = loc("do_adj2"(#loc35)) +#loc339 = loc("off_chz2"(#loc36)) +#loc340 = loc("off_chz2"(#loc37)) +#loc341 = loc("off_chz2"(#loc38)) +#loc342 = loc("off_chz2"(#loc39)) +#loc343 = loc("Q2"(#loc40)) +#loc344 = loc("DO2"(#loc41)) +#loc345 = loc("DQ2"(#loc42)) +#loc346 = loc("LSE2"(#loc43)) +#loc347 = loc("DELTA2"(#loc44)) +#loc348 = loc("start_m2"(#loc45)) +#loc349 = loc("offs_m2"(#loc46)) +#loc350 = loc("ptr"(#loc47)) +#loc351 = loc("q"(#loc48)) +#loc352 = loc("ptr"(#loc49)) +#loc353 = loc("ptr"(#loc50)) +#loc354 = loc("ptr"(#loc51)) +#loc355 = loc("ptr"(#loc52)) +#loc356 = loc("do"(#loc54)) +#loc357 = loc("Di"(#loc55)) +#loc358 = loc("Di"(#loc56)) +#loc359 = loc("lse"(#loc57)) +#loc360 = loc("lse"(#loc58)) +#loc361 = loc("lse"(#loc59)) +#loc362 = loc("lse"(#loc60)) +#loc363 = loc("lse"(#loc61)) +#loc364 = loc("kv_indices"(#loc62)) +#loc365 = loc("kv_start"(#loc63)) +#loc366 = loc("kv_start"(#loc64)) +#loc367 = loc("sparse_kv_num_blocks"(#loc65)) +#loc368 = loc("sparse_kv_num_blocks"(#loc66)) +#loc369 = loc("offs_n2"(#loc67)) +#loc370 = loc("offs_n2"(#loc68)) +#loc371 = loc("kT_ptrs"(#loc69)) +#loc372 = loc("dq"(#loc70)) +#loc373 = loc("kT_ptrs"(#loc71)) +#loc374 = loc("kT_ptrs"(#loc72)) +#loc375 = loc("kT_ptrs"(#loc73)) +#loc376 = loc("kT_ptrs"(#loc74)) +#loc377 = loc("vT_ptrs"(#loc75)) +#loc378 = loc("vT_ptrs"(#loc76)) +#loc379 = loc("hi"(#loc77)) +#loc380 = loc("hi"(#loc78)) +#loc381 = loc("tmp4"(#loc79)) +#loc382 = loc("dq"(#loc80)) +#loc383 = loc("tmp7"(#loc81)) +#loc384 = loc("dq"(#loc82)) +#loc385 = loc("tmp7"(#loc83)) +#loc386 = loc("tmp8"(#loc84)) +#loc387 = loc("tmp9"(#loc85)) +#loc388 = loc("tmp10"(#loc86)) +#loc389 = loc("tmp11"(#loc87)) +#loc390 = loc("p"(#loc88)) +#loc391 = loc("ds"(#loc89)) +#loc392 = loc("ds"(#loc90)) +#loc393 = loc("kT"(#loc91)) +#loc394 = 
loc("vT"(#loc92)) +#loc395 = loc("kT_ptrs"(#loc93)) +#loc396 = loc("vT_ptrs"(#loc94)) +#loc397 = loc("qk"(#loc95)) +#loc398 = loc("dq"(#loc96)) +#loc399 = loc("qk"(#loc97)) +#loc400 = loc("n"(#loc98)) +#loc401 = loc("tmp5"(#loc99)) +#loc402 = loc("tmp12"(#loc100)) +#loc403 = loc("tmp15"(#loc101)) +#loc404 = loc("tmp16"(#loc102)) +#loc405 = loc("tmp18"(#loc103)) +#loc406 = loc("tmp19"(#loc104)) +#loc407 = loc("tmp22"(#loc105)) +#loc408 = loc("tmp23"(#loc106)) +#loc409 = loc("tmp24"(#loc107)) +#loc410 = loc("tmp25"(#loc108)) +#loc411 = loc("tmp26"(#loc109)) +#loc412 = loc("tmp27"(#loc110)) +#loc413 = loc("tmp28"(#loc111)) +#loc414 = loc("tmp29"(#loc112)) +#loc415 = loc("tmp30"(#loc113)) +#loc416 = loc("tmp31"(#loc114)) +#loc417 = loc("tmp33"(#loc115)) +#loc418 = loc("tmp34"(#loc116)) +#loc419 = loc("tmp35"(#loc117)) +#loc420 = loc("tmp36"(#loc118)) +#loc421 = loc("tmp37"(#loc119)) +#loc422 = loc("tmp38"(#loc120)) +#loc423 = loc("post_mod_scores"(#loc121)) +#loc424 = loc("post_mod_scores"(#loc122)) +#loc425 = loc("p"(#loc123)) +#loc426 = loc("dp"(#loc124)) +#loc427 = loc("ds"(#loc125)) +#loc428 = loc("ds"(#loc126)) +#loc429 = loc("ds"(#loc127)) +#loc430 = loc("dq"(#loc128)) +#loc431 = loc("offs_n2"(#loc129)) +#loc432 = loc("cur_block_idx"(#loc130)) +#loc433 = loc("offset"(#loc131)) +#loc434 = loc("cur_block"(#loc132)) +#loc435 = loc("cur_block"(#loc133)) +#loc436 = loc("next_block"(#loc134)) +#loc437 = loc("next_block"(#loc135)) +#loc438 = loc("next_block"(#loc136)) +#loc439 = loc("next_block"(#loc137)) +#loc440 = loc("needs_jump"(#loc138)) +#loc441 = loc("needs_jump"(#loc139)) +#loc442 = loc("needs_jump"(#loc140)) +#loc443 = loc("jump_to_block"(#loc141)) +#loc444 = loc("jump_to_block"(#loc142)) +#loc445 = loc("jump_to_block"(#loc143)) +#loc446 = loc("offset"(#loc144)) +#loc447 = loc("offset"(#loc145)) +#loc448 = loc("offset"(#loc146)) +#loc449 = loc("offset"(#loc147)) +#loc450 = loc("kT_ptrs"(#loc148)) +#loc451 = loc("kv_indices"(#loc149)) +#loc452 = 
loc("kv_start"(#loc150)) +#loc453 = loc("kv_start"(#loc151)) +#loc454 = loc("sparse_kv_num_blocks"(#loc152)) +#loc455 = loc("sparse_kv_num_blocks"(#loc153)) +#loc456 = loc("offs_n2"(#loc154)) +#loc457 = loc("dq"(#loc155)) +#loc458 = loc("dq_ptrs"(#loc156)) +#loc459 = loc("dq_ptrs"(#loc157)) +#loc460 = loc("dq"(#loc158)) +#loc461 = loc("start_n1"(#loc160)) +#loc462 = loc("offs_n1"(#loc161)) +#loc463 = loc("k"(#loc162)) +#loc464 = loc("v"(#loc163)) +#loc465 = loc("off_hq1"(#loc164)) +#loc466 = loc("q_adj1"(#loc165)) +#loc467 = loc("off_chz1"(#loc166)) +#loc468 = loc("sparse_q_num_blks_offset"(#loc167)) +#loc469 = loc("sparse_q_num_blks_offset"(#loc168)) +#loc470 = loc("sparse_q_idx_offset"(#loc169)) +#loc471 = loc("sparse_q_idx_offset"(#loc170)) +#loc472 = loc("sparse_q_idx_offset"(#loc171)) +#loc473 = loc("q_indices"(#loc172)) +#loc474 = loc("q_start"(#loc173)) +#loc475 = loc("q_start"(#loc174)) +#loc476 = loc("sparse_q_num_blocks"(#loc175)) +#loc477 = loc("sparse_q_num_blocks"(#loc176)) +#loc478 = loc("offs_m1"(#loc177)) +#loc479 = loc("offs_m1"(#loc178)) +#loc480 = loc("qT_ptrs"(#loc179)) +#loc481 = loc("qT_ptrs"(#loc181)) +#loc482 = loc("qT_ptrs"(#loc182)) +#loc483 = loc("qT_ptrs"(#loc183)) +#loc484 = loc("do_ptrs"(#loc184)) +#loc485 = loc("do_ptrs"(#loc185)) +#loc486 = loc("do_ptrs"(#loc186)) +#loc487 = loc("hi"(#loc187)) +#loc488 = loc("hi"(#loc188)) +#loc489 = loc("tmp44"(#loc189)) +#loc490 = loc(callsite(#loc190 at #loc180)) +#loc491 = loc("tmp45"(#loc191)) +#loc492 = loc("tmp47"(#loc192)) +#loc493 = loc("dk"(#loc193)) +#loc494 = loc("tmp47"(#loc194)) +#loc495 = loc("tmp48"(#loc195)) +#loc496 = loc("tmp50"(#loc196)) +#loc497 = loc("tmp51"(#loc197)) +#loc498 = loc("tmp55"(#loc198)) +#loc499 = loc("tmp56"(#loc199)) +#loc500 = loc("tmp58"(#loc200)) +#loc501 = loc("tmp59"(#loc201)) +#loc502 = loc("tmp62"(#loc202)) +#loc503 = loc("tmp63"(#loc203)) +#loc504 = loc("tmp64"(#loc204)) +#loc505 = loc("tmp65"(#loc205)) +#loc506 = loc("tmp66"(#loc206)) +#loc507 = 
loc("tmp67"(#loc207)) +#loc508 = loc("tmp77"(#loc208)) +#loc509 = loc("q_indices"(#loc209)) +#loc510 = loc("q_start"(#loc210)) +#loc511 = loc("q_start"(#loc211)) +#loc512 = loc("sparse_q_num_blocks"(#loc212)) +#loc513 = loc("sparse_q_num_blocks"(#loc213)) +#loc514 = loc("offs_m1"(#loc214)) +#loc515 = loc("qkT"(#loc216)) +#loc516 = loc("dv"(#loc217)) +#loc517 = loc("off_hq1"(#loc218)) +#loc518 = loc("q_adj1"(#loc219)) +#loc519 = loc("q_adj1"(#loc220)) +#loc520 = loc("q_adj1"(#loc221)) +#loc521 = loc("do_adj1"(#loc222)) +#loc522 = loc("do_adj1"(#loc223)) +#loc523 = loc("do_adj1"(#loc224)) +#loc524 = loc("off_chz1"(#loc225)) +#loc525 = loc("off_chz1"(#loc226)) +#loc526 = loc("off_chz1"(#loc227)) +#loc527 = loc("Q1"(#loc228)) +#loc528 = loc("DO1"(#loc229)) +#loc529 = loc("LSE1"(#loc230)) +#loc530 = loc("DELTA1"(#loc231)) +#loc531 = loc("qT_ptrs"(#loc232)) +#loc532 = loc("do_ptrs"(#loc233)) +#loc533 = loc("lse"(#loc234)) +#loc534 = loc("Di"(#loc235)) +#loc535 = loc("qT"(#loc236)) +#loc536 = loc("lse"(#loc237)) +#loc537 = loc("do"(#loc238)) +#loc538 = loc("Di"(#loc239)) +#loc539 = loc("qT_ptrs"(#loc240)) +#loc540 = loc("do_ptrs"(#loc241)) +#loc541 = loc("offs_m1"(#loc242)) +#loc542 = loc("dk"(#loc243)) +#loc543 = loc("lse"(#loc244)) +#loc544 = loc("lse"(#loc245)) +#loc545 = loc("qkT"(#loc246)) +#loc546 = loc("m"(#loc247)) +#loc547 = loc("tmp49"(#loc248)) +#loc548 = loc("tmp52"(#loc249)) +#loc549 = loc("tmp68"(#loc250)) +#loc550 = loc("tmp69"(#loc251)) +#loc551 = loc("tmp70"(#loc252)) +#loc552 = loc("tmp71"(#loc253)) +#loc553 = loc("tmp73"(#loc254)) +#loc554 = loc("tmp74"(#loc255)) +#loc555 = loc("tmp75"(#loc256)) +#loc556 = loc("tmp76"(#loc257)) +#loc557 = loc("tmp78"(#loc258)) +#loc558 = loc("post_mod_scores"(#loc259)) +#loc559 = loc("post_mod_scores"(#loc260)) +#loc560 = loc("pT"(#loc261)) +#loc561 = loc("pT"(#loc262)) +#loc562 = loc("pT"(#loc263)) +#loc563 = loc("dpT"(#loc264)) +#loc564 = loc("dv"(#loc265)) +#loc565 = loc("dv"(#loc266)) +#loc566 = loc("dpT"(#loc267)) 
+#loc567 = loc("dsT"(#loc268)) +#loc568 = loc("dsT"(#loc269)) +#loc569 = loc("dsT"(#loc270)) +#loc570 = loc("dsT"(#loc271)) +#loc571 = loc("dk"(#loc272)) +#loc572 = loc("dk"(#loc273)) +#loc573 = loc("offset"(#loc274)) +#loc574 = loc("qT_ptrs"(#loc275)) +#loc575 = loc("do_ptrs"(#loc276)) +#loc576 = loc(callsite(#loc190 at #loc215)) +#loc577 = loc("dv_ptrs"(#loc278)) +#loc578 = loc("dv_ptrs"(#loc279)) +#loc579 = loc("dk"(#loc281)) +#loc580 = loc("mask"(#loc282)) +#loc581 = loc("xindex"(#loc283)) +#loc582 = loc("xindex"(#loc284)) +#loc583 = loc("xindex"(#loc285)) +#loc584 = loc(callsite(#loc350 at #loc351)) +#loc585 = loc(callsite(#loc352 at #loc351)) +#loc586 = loc(callsite(#loc353 at #loc351)) +#loc587 = loc(callsite(#loc354 at #loc351)) +#loc588 = loc(callsite(#loc355 at #loc351)) +#loc589 = loc(callsite(#loc53 at #loc351)) +#loc590 = loc(callsite(#loc352 at #loc356)) +#loc591 = loc(callsite(#loc353 at #loc356)) +#loc592 = loc(callsite(#loc355 at #loc356)) +#loc593 = loc(callsite(#loc53 at #loc356)) +#loc594 = loc(callsite(#loc371 at #loc372)) +#loc595 = loc(callsite(#loc373 at #loc372)) +#loc596 = loc(callsite(#loc374 at #loc372)) +#loc597 = loc(callsite(#loc375 at #loc372)) +#loc598 = loc(callsite(#loc376 at #loc372)) +#loc599 = loc(callsite(#loc377 at #loc372)) +#loc600 = loc(callsite(#loc378 at #loc372)) +#loc601 = loc(callsite(#loc379 at #loc372)) +#loc602 = loc(callsite(#loc380 at #loc372)) +#loc603 = loc(callsite(#loc382 at #loc372)) +#loc604 = loc("offs_n2"(#loc384)) +#loc605 = loc(callsite(#loc395 at #loc372)) +#loc606 = loc(callsite(#loc396 at #loc372)) +#loc607 = loc(callsite(#loc431 at #loc372)) +#loc608 = loc(callsite(#loc433 at #loc372)) +#loc609 = loc(callsite(#loc450 at #loc372)) +#loc610 = loc(callsite(#loc371 at #loc457)) +#loc611 = loc(callsite(#loc373 at #loc457)) +#loc612 = loc(callsite(#loc374 at #loc457)) +#loc613 = loc(callsite(#loc376 at #loc457)) +#loc614 = loc(callsite(#loc377 at #loc457)) +#loc615 = loc(callsite(#loc378 at #loc457)) 
+#loc616 = loc(callsite(#loc379 at #loc457)) +#loc617 = loc(callsite(#loc380 at #loc457)) +#loc618 = loc(callsite(#loc382 at #loc457)) +#loc619 = loc(callsite(#loc395 at #loc457)) +#loc620 = loc(callsite(#loc396 at #loc457)) +#loc621 = loc(callsite(#loc433 at #loc457)) +#loc622 = loc(callsite(#loc450 at #loc457)) +#loc623 = loc(callsite(#loc350 at #loc463)) +#loc624 = loc(callsite(#loc352 at #loc463)) +#loc625 = loc(callsite(#loc353 at #loc463)) +#loc626 = loc(callsite(#loc354 at #loc463)) +#loc627 = loc(callsite(#loc355 at #loc463)) +#loc628 = loc(callsite(#loc53 at #loc463)) +#loc629 = loc(callsite(#loc353 at #loc464)) +#loc630 = loc(callsite(#loc355 at #loc464)) +#loc631 = loc(callsite(#loc53 at #loc464)) +#loc632 = loc(callsite(#loc480 at #loc180)) +#loc633 = loc(callsite(#loc481 at #loc180)) +#loc634 = loc(callsite(#loc482 at #loc180)) +#loc635 = loc(callsite(#loc483 at #loc180)) +#loc636 = loc(callsite(#loc484 at #loc180)) +#loc637 = loc(callsite(#loc485 at #loc180)) +#loc638 = loc(callsite(#loc486 at #loc180)) +#loc639 = loc(callsite(#loc487 at #loc180)) +#loc640 = loc(callsite(#loc488 at #loc180)) +#loc641 = loc(callsite(#loc489 at #loc490)) +#loc642 = loc(callsite(#loc491 at #loc490)) +#loc643 = loc(callsite(#loc492 at #loc490)) +#loc644 = loc("dv"(#loc493)) +#loc645 = loc(callsite(#loc494 at #loc490)) +#loc646 = loc(callsite(#loc495 at #loc490)) +#loc647 = loc(callsite(#loc496 at #loc490)) +#loc648 = loc(callsite(#loc497 at #loc490)) +#loc649 = loc(callsite(#loc498 at #loc490)) +#loc650 = loc(callsite(#loc499 at #loc490)) +#loc651 = loc(callsite(#loc500 at #loc490)) +#loc652 = loc(callsite(#loc501 at #loc490)) +#loc653 = loc(callsite(#loc502 at #loc490)) +#loc654 = loc(callsite(#loc503 at #loc490)) +#loc655 = loc(callsite(#loc504 at #loc490)) +#loc656 = loc(callsite(#loc505 at #loc490)) +#loc657 = loc(callsite(#loc506 at #loc490)) +#loc658 = loc(callsite(#loc507 at #loc490)) +#loc659 = loc(callsite(#loc508 at #loc490)) +#loc660 = loc(callsite(#loc480 at 
#loc215)) +#loc661 = loc(callsite(#loc481 at #loc215)) +#loc662 = loc(callsite(#loc484 at #loc215)) +#loc663 = loc(callsite(#loc485 at #loc215)) +#loc664 = loc(callsite(#loc487 at #loc215)) +#loc665 = loc(callsite(#loc488 at #loc215)) +#loc666 = loc(callsite(#loc515 at #loc490)) +#loc667 = loc("dk"(#loc516)) +#loc668 = loc(callsite(#loc531 at #loc180)) +#loc669 = loc(callsite(#loc532 at #loc180)) +#loc670 = loc(callsite(#loc533 at #loc490)) +#loc671 = loc(callsite(#loc534 at #loc490)) +#loc672 = loc(callsite(#loc535 at #loc490)) +#loc673 = loc(callsite(#loc536 at #loc490)) +#loc674 = loc(callsite(#loc537 at #loc490)) +#loc675 = loc(callsite(#loc538 at #loc490)) +#loc676 = loc(callsite(#loc539 at #loc180)) +#loc677 = loc(callsite(#loc540 at #loc180)) +#loc678 = loc(callsite(#loc541 at #loc180)) +#loc679 = loc(callsite(#loc542 at #loc490)) +#loc680 = loc(callsite(#loc543 at #loc490)) +#loc681 = loc(callsite(#loc544 at #loc490)) +#loc682 = loc(callsite(#loc545 at #loc490)) +#loc683 = loc(callsite(#loc546 at #loc490)) +#loc684 = loc(callsite(#loc547 at #loc490)) +#loc685 = loc(callsite(#loc548 at #loc490)) +#loc686 = loc(callsite(#loc549 at #loc490)) +#loc687 = loc(callsite(#loc550 at #loc490)) +#loc688 = loc(callsite(#loc551 at #loc490)) +#loc689 = loc(callsite(#loc552 at #loc490)) +#loc690 = loc(callsite(#loc553 at #loc490)) +#loc691 = loc(callsite(#loc554 at #loc490)) +#loc692 = loc(callsite(#loc555 at #loc490)) +#loc693 = loc(callsite(#loc556 at #loc490)) +#loc694 = loc(callsite(#loc557 at #loc490)) +#loc695 = loc(callsite(#loc558 at #loc490)) +#loc696 = loc(callsite(#loc559 at #loc490)) +#loc697 = loc(callsite(#loc560 at #loc490)) +#loc698 = loc(callsite(#loc561 at #loc490)) +#loc699 = loc(callsite(#loc562 at #loc490)) +#loc700 = loc(callsite(#loc563 at #loc490)) +#loc701 = loc(callsite(#loc564 at #loc490)) +#loc702 = loc(callsite(#loc565 at #loc490)) +#loc703 = loc(callsite(#loc566 at #loc490)) +#loc704 = loc(callsite(#loc567 at #loc490)) +#loc705 = 
loc(callsite(#loc568 at #loc490)) +#loc706 = loc(callsite(#loc569 at #loc490)) +#loc707 = loc(callsite(#loc570 at #loc490)) +#loc708 = loc(callsite(#loc571 at #loc490)) +#loc709 = loc(callsite(#loc572 at #loc490)) +#loc710 = loc(callsite(#loc573 at #loc180)) +#loc711 = loc(callsite(#loc574 at #loc180)) +#loc712 = loc(callsite(#loc575 at #loc180)) +#loc713 = loc(callsite(#loc531 at #loc215)) +#loc714 = loc(callsite(#loc483 at #loc215)) +#loc715 = loc(callsite(#loc532 at #loc215)) +#loc716 = loc(callsite(#loc486 at #loc215)) +#loc717 = loc(callsite(#loc535 at #loc576)) +#loc718 = loc(callsite(#loc536 at #loc576)) +#loc719 = loc(callsite(#loc537 at #loc576)) +#loc720 = loc(callsite(#loc538 at #loc576)) +#loc721 = loc(callsite(#loc533 at #loc576)) +#loc722 = loc(callsite(#loc534 at #loc576)) +#loc723 = loc(callsite(#loc539 at #loc215)) +#loc724 = loc(callsite(#loc540 at #loc215)) +#loc725 = loc(callsite(#loc541 at #loc215)) +#loc726 = loc(callsite(#loc542 at #loc576)) +#loc727 = loc(callsite(#loc543 at #loc576)) +#loc728 = loc(callsite(#loc544 at #loc576)) +#loc729 = loc(callsite(#loc515 at #loc576)) +#loc730 = loc(callsite(#loc545 at #loc576)) +#loc731 = loc(callsite(#loc559 at #loc576)) +#loc732 = loc(callsite(#loc560 at #loc576)) +#loc733 = loc(callsite(#loc561 at #loc576)) +#loc734 = loc(callsite(#loc562 at #loc576)) +#loc735 = loc(callsite(#loc563 at #loc576)) +#loc736 = loc(callsite(#loc564 at #loc576)) +#loc737 = loc(callsite(#loc565 at #loc576)) +#loc738 = loc(callsite(#loc566 at #loc576)) +#loc739 = loc(callsite(#loc567 at #loc576)) +#loc740 = loc(callsite(#loc568 at #loc576)) +#loc741 = loc(callsite(#loc569 at #loc576)) +#loc742 = loc(callsite(#loc571 at #loc576)) +#loc743 = loc(callsite(#loc572 at #loc576)) +#loc744 = loc(callsite(#loc573 at #loc215)) +#loc745 = loc(callsite(#loc574 at #loc215)) +#loc746 = loc(callsite(#loc575 at #loc215)) +#loc747 = loc(callsite(#loc381 at #loc603)) +#loc748 = loc(callsite(#loc383 at #loc603)) +#loc749 = 
loc("kT_ptrs"(#loc604)) +#loc750 = loc(callsite(#loc385 at #loc603)) +#loc751 = loc(callsite(#loc386 at #loc603)) +#loc752 = loc(callsite(#loc387 at #loc603)) +#loc753 = loc(callsite(#loc388 at #loc603)) +#loc754 = loc(callsite(#loc389 at #loc603)) +#loc755 = loc(callsite(#loc390 at #loc603)) +#loc756 = loc(callsite(#loc391 at #loc603)) +#loc757 = loc(callsite(#loc392 at #loc603)) +#loc758 = loc(callsite(#loc393 at #loc603)) +#loc759 = loc(callsite(#loc394 at #loc603)) +#loc760 = loc(callsite(#loc397 at #loc603)) +#loc761 = loc(callsite(#loc398 at #loc603)) +#loc762 = loc(callsite(#loc399 at #loc603)) +#loc763 = loc(callsite(#loc400 at #loc603)) +#loc764 = loc(callsite(#loc401 at #loc603)) +#loc765 = loc(callsite(#loc402 at #loc603)) +#loc766 = loc(callsite(#loc403 at #loc603)) +#loc767 = loc(callsite(#loc404 at #loc603)) +#loc768 = loc(callsite(#loc405 at #loc603)) +#loc769 = loc(callsite(#loc406 at #loc603)) +#loc770 = loc(callsite(#loc407 at #loc603)) +#loc771 = loc(callsite(#loc408 at #loc603)) +#loc772 = loc(callsite(#loc409 at #loc603)) +#loc773 = loc(callsite(#loc410 at #loc603)) +#loc774 = loc(callsite(#loc411 at #loc603)) +#loc775 = loc(callsite(#loc412 at #loc603)) +#loc776 = loc(callsite(#loc413 at #loc603)) +#loc777 = loc(callsite(#loc414 at #loc603)) +#loc778 = loc(callsite(#loc415 at #loc603)) +#loc779 = loc(callsite(#loc416 at #loc603)) +#loc780 = loc(callsite(#loc417 at #loc603)) +#loc781 = loc(callsite(#loc418 at #loc603)) +#loc782 = loc(callsite(#loc419 at #loc603)) +#loc783 = loc(callsite(#loc420 at #loc603)) +#loc784 = loc(callsite(#loc421 at #loc603)) +#loc785 = loc(callsite(#loc422 at #loc603)) +#loc786 = loc(callsite(#loc423 at #loc603)) +#loc787 = loc(callsite(#loc424 at #loc603)) +#loc788 = loc(callsite(#loc425 at #loc603)) +#loc789 = loc(callsite(#loc426 at #loc603)) +#loc790 = loc(callsite(#loc427 at #loc603)) +#loc791 = loc(callsite(#loc428 at #loc603)) +#loc792 = loc(callsite(#loc429 at #loc603)) +#loc793 = loc(callsite(#loc430 at 
#loc603)) +#loc794 = loc(callsite(#loc432 at #loc608)) +#loc795 = loc(callsite(#loc434 at #loc608)) +#loc796 = loc(callsite(#loc435 at #loc608)) +#loc797 = loc(callsite(#loc436 at #loc608)) +#loc798 = loc(callsite(#loc437 at #loc608)) +#loc799 = loc(callsite(#loc438 at #loc608)) +#loc800 = loc(callsite(#loc439 at #loc608)) +#loc801 = loc(callsite(#loc440 at #loc608)) +#loc802 = loc(callsite(#loc441 at #loc608)) +#loc803 = loc(callsite(#loc442 at #loc608)) +#loc804 = loc(callsite(#loc443 at #loc608)) +#loc805 = loc(callsite(#loc444 at #loc608)) +#loc806 = loc(callsite(#loc445 at #loc608)) +#loc807 = loc(callsite(#loc446 at #loc608)) +#loc808 = loc(callsite(#loc447 at #loc608)) +#loc809 = loc(callsite(#loc448 at #loc608)) +#loc810 = loc(callsite(#loc449 at #loc608)) +#loc811 = loc(callsite(#loc393 at #loc618)) +#loc812 = loc(callsite(#loc394 at #loc618)) +#loc813 = loc(callsite(#loc397 at #loc618)) +#loc814 = loc(callsite(#loc398 at #loc618)) +#loc815 = loc(callsite(#loc399 at #loc618)) +#loc816 = loc(callsite(#loc424 at #loc618)) +#loc817 = loc(callsite(#loc390 at #loc618)) +#loc818 = loc(callsite(#loc425 at #loc618)) +#loc819 = loc(callsite(#loc426 at #loc618)) +#loc820 = loc(callsite(#loc392 at #loc618)) +#loc821 = loc(callsite(#loc427 at #loc618)) +#loc822 = loc(callsite(#loc429 at #loc618)) +#loc823 = loc(callsite(#loc430 at #loc618)) +#loc824 = loc(callsite(#loc432 at #loc621)) +#loc825 = loc(callsite(#loc434 at #loc621)) +#loc826 = loc(callsite(#loc435 at #loc621)) +#loc827 = loc(callsite(#loc436 at #loc621)) +#loc828 = loc(callsite(#loc437 at #loc621)) +#loc829 = loc(callsite(#loc438 at #loc621)) +#loc830 = loc(callsite(#loc439 at #loc621)) +#loc831 = loc(callsite(#loc440 at #loc621)) +#loc832 = loc(callsite(#loc441 at #loc621)) +#loc833 = loc(callsite(#loc442 at #loc621)) +#loc834 = loc(callsite(#loc443 at #loc621)) +#loc835 = loc(callsite(#loc444 at #loc621)) +#loc836 = loc(callsite(#loc445 at #loc621)) +#loc837 = loc(callsite(#loc446 at #loc621)) +#loc838 
= loc(callsite(#loc447 at #loc621)) +#loc839 = loc(callsite(#loc448 at #loc621)) +#loc840 = loc(callsite(#loc449 at #loc621)) +#loc841 = loc("offs_m1"(#loc644)) +#loc842 = loc(callsite(#loc53 at #loc672)) +#loc843 = loc(callsite(#loc53 at #loc674)) +#loc844 = loc(callsite(#loc432 at #loc710)) +#loc845 = loc(callsite(#loc434 at #loc710)) +#loc846 = loc(callsite(#loc435 at #loc710)) +#loc847 = loc(callsite(#loc436 at #loc710)) +#loc848 = loc(callsite(#loc437 at #loc710)) +#loc849 = loc(callsite(#loc438 at #loc710)) +#loc850 = loc(callsite(#loc439 at #loc710)) +#loc851 = loc(callsite(#loc440 at #loc710)) +#loc852 = loc(callsite(#loc441 at #loc710)) +#loc853 = loc(callsite(#loc442 at #loc710)) +#loc854 = loc(callsite(#loc443 at #loc710)) +#loc855 = loc(callsite(#loc444 at #loc710)) +#loc856 = loc(callsite(#loc445 at #loc710)) +#loc857 = loc(callsite(#loc446 at #loc710)) +#loc858 = loc(callsite(#loc447 at #loc710)) +#loc859 = loc(callsite(#loc448 at #loc710)) +#loc860 = loc(callsite(#loc449 at #loc710)) +#loc861 = loc(callsite(#loc53 at #loc717)) +#loc862 = loc(callsite(#loc53 at #loc719)) +#loc863 = loc(callsite(#loc432 at #loc744)) +#loc864 = loc(callsite(#loc434 at #loc744)) +#loc865 = loc(callsite(#loc435 at #loc744)) +#loc866 = loc(callsite(#loc436 at #loc744)) +#loc867 = loc(callsite(#loc437 at #loc744)) +#loc868 = loc(callsite(#loc438 at #loc744)) +#loc869 = loc(callsite(#loc439 at #loc744)) +#loc870 = loc(callsite(#loc440 at #loc744)) +#loc871 = loc(callsite(#loc441 at #loc744)) +#loc872 = loc(callsite(#loc442 at #loc744)) +#loc873 = loc(callsite(#loc443 at #loc744)) +#loc874 = loc(callsite(#loc444 at #loc744)) +#loc875 = loc(callsite(#loc445 at #loc744)) +#loc876 = loc(callsite(#loc446 at #loc744)) +#loc877 = loc(callsite(#loc447 at #loc744)) +#loc878 = loc(callsite(#loc448 at #loc744)) +#loc879 = loc(callsite(#loc449 at #loc744)) +#loc880 = loc("vT_ptrs"(#loc749)) +#loc881 = loc(callsite(#loc53 at #loc758)) +#loc882 = loc(callsite(#loc53 at #loc759)) +#loc883 
= loc(callsite(#loc53 at #loc811)) +#loc884 = loc(callsite(#loc53 at #loc812)) +#loc885 = loc("qT_ptrs"(#loc841)) +#loc886 = loc(callsite(#loc880 at #loc372)) +#loc887 = loc(callsite(#loc880 at #loc457)) +#loc888 = loc("do_ptrs"(#loc885)) +#loc889 = loc(callsite(#loc888 at #loc180)) +#loc890 = loc(callsite(#loc888 at #loc215)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttir new file mode 100644 index 0000000000000000000000000000000000000000..a91e1206a8cf0208c9bd79bdde3c1f329010f5bd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/HULI6VL32CI7Q6UGP75BIUMHXEMSAWXMPR3LDJJLGOCFNMUFGPVQ/triton_tem_fused_zeros_1.ttir @@ -0,0 +1,1440 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":18:0) +#loc291 = loc("arg_Q"(#loc)) +#loc292 = loc("arg_K"(#loc)) +#loc293 = loc("arg_V"(#loc)) +#loc294 = loc("arg_LSE"(#loc)) +#loc295 = loc("arg_DELTA"(#loc)) +#loc296 = loc("arg_DO"(#loc)) +#loc297 = loc("arg_DQ"(#loc)) +#loc298 = loc("arg_DV"(#loc)) +#loc299 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc300 = loc("arg_KV_IDX"(#loc)) +#loc301 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc302 = loc("arg_Q_IDX"(#loc)) +#loc303 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc304 = loc("arg_FULL_KV_IDX"(#loc)) +#loc305 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc306 = loc("arg_FULL_Q_IDX"(#loc)) +#loc307 = loc("in_ptr16"(#loc)) +#loc308 = loc("out_ptr0"(#loc)) +module { + tt.func public @triton_tem_fused_zeros_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility 
= 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %in_ptr16: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr16"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<64x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<4096> : tensor<1x64xi32> loc(#loc1) + %cst_1 = arith.constant dense<128> : tensor<1x64xi32> loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1) + %cst_4 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_5 = arith.constant dense<2048> : tensor<128x64xi32> loc(#loc1) + %cst_6 = arith.constant dense<2048> : tensor<1x64xi32> loc(#loc1) + %cst_7 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1) + %cst_8 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1) + %cst_9 = arith.constant dense<0> : tensor<128x64xi32> loc(#loc1) + %cst_10 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1) + %cst_11 = arith.constant 
dense<0.0883883461> : tensor<128x64xf32> loc(#loc1) + %cst_12 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1) + %cst_13 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_14 = arith.constant dense<2048> : tensor<128x1xi32> loc(#loc1) + %cst_15 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc1) + %cst_16 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc1) + %cst_17 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc1) + %cst_18 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc1) + %cst_19 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c8388608_i32 = arith.constant 8388608 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c2097152_i32 = arith.constant 2097152 : i32 loc(#loc1) + %c262144_i32 = arith.constant 262144 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc1) + %pid = tt.get_program_id x : i32 loc(#loc309) + %off_zq = tt.get_program_id y : i32 loc(#loc310) + %off_hkv = tt.get_program_id z : i32 loc(#loc311) + %off_zkv = arith.remsi %off_zq, %c2_i32 : i32 loc(#loc312) + %k_adj = arith.muli %off_hkv, %c262144_i32 : i32 loc(#loc313) + %k_adj_20 = arith.muli %off_zkv, %c2097152_i32 : i32 loc(#loc314) + %k_adj_21 = arith.addi %k_adj, %k_adj_20 : i32 loc(#loc315) + %k_adj_22 = arith.extsi %k_adj_21 : i32 to i64 loc(#loc316) + %dv_adj = arith.muli %off_zq, %c2097152_i32 : i32 loc(#loc317) + %dv_adj_23 = arith.addi %k_adj, %dv_adj : i32 loc(#loc318) + %dv_adj_24 = arith.extsi %dv_adj_23 : i32 to i64 loc(#loc319) + %K = tt.addptr %arg_K, %k_adj_22 
: !tt.ptr, i64 loc(#loc320) + %V = tt.addptr %arg_V, %k_adj_22 : !tt.ptr, i64 loc(#loc321) + %DV = tt.addptr %arg_DV, %dv_adj_24 : !tt.ptr, i64 loc(#loc322) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc323) + %0 = arith.cmpi sge, %pid, %c16_i32 : i32 loc(#loc17) + scf.if %0 { + %off_pid = arith.subi %pid, %c16_i32 : i32 loc(#loc324) + %off_hq2 = arith.divsi %off_pid, %c16_i32 : i32 loc(#loc325) + %off_hq2_25 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc326) + %off_hq2_26 = arith.addi %off_hq2, %off_hq2_25 : i32 loc(#loc327) + %start_m2_block = arith.remsi %off_pid, %c16_i32 : i32 loc(#loc328) + %sparse_kv_num_blks_offset = arith.muli %off_zkv, %c16_i32 : i32 loc(#loc329) + %sparse_kv_num_blks_offset_27 = arith.addi %sparse_kv_num_blks_offset, %start_m2_block : i32 loc(#loc330) + %sparse_kv_idx_offset = arith.muli %off_zkv, %c256_i32 : i32 loc(#loc331) + %sparse_kv_idx_offset_28 = arith.muli %start_m2_block, %c16_i32 : i32 loc(#loc332) + %sparse_kv_idx_offset_29 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_28 : i32 loc(#loc333) + %q_adj2 = arith.muli %off_hq2_26, %c128_i32 : i32 loc(#loc334) + %q_adj2_30 = arith.muli %off_zq, %c8388608_i32 : i32 loc(#loc335) + %q_adj2_31 = arith.addi %q_adj2, %q_adj2_30 : i32 loc(#loc336) + %q_adj2_32 = arith.extsi %q_adj2_31 : i32 to i64 loc(#loc337) + %do_adj2 = arith.muli %off_hq2_26, %c262144_i32 : i32 loc(#loc338) + %do_adj2_33 = arith.addi %do_adj2, %q_adj2_30 : i32 loc(#loc339) + %do_adj2_34 = arith.extsi %do_adj2_33 : i32 to i64 loc(#loc340) + %off_chz2 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc341) + %off_chz2_35 = arith.addi %off_chz2, %off_hq2_26 : i32 loc(#loc342) + %off_chz2_36 = arith.muli %off_chz2_35, %c2048_i32 : i32 loc(#loc343) + %off_chz2_37 = arith.extsi %off_chz2_36 : i32 to i64 loc(#loc344) + %Q2 = tt.addptr %arg_Q, %q_adj2_32 : !tt.ptr, i64 loc(#loc345) + %DO2 = tt.addptr %arg_DO, %do_adj2_34 : !tt.ptr, i64 loc(#loc346) + %DQ2 = tt.addptr %arg_DQ, 
%q_adj2_32 : !tt.ptr, i64 loc(#loc347) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_37 : !tt.ptr, i64 loc(#loc348) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_37 : !tt.ptr, i64 loc(#loc349) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc350) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32> loc(#loc351) + %offs_m2_38 = arith.addi %offs_m2, %offs_k : tensor<128xi32> loc(#loc351) + %ptr = tt.expand_dims %offs_m2_38 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc588) + %ptr_39 = arith.muli %ptr, %cst_17 : tensor<128x1xi32> loc(#loc589) + %ptr_40 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc590) + %ptr_41 = tt.addptr %ptr_40, %ptr_39 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc590) + %ptr_42 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc591) + %ptr_43 = tt.broadcast %ptr_41 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc592) + %ptr_44 = tt.broadcast %ptr_42 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc592) + %ptr_45 = tt.addptr %ptr_43, %ptr_44 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc592) + %q = tt.load %ptr_45 : tensor<128x128x!tt.ptr> loc(#loc593) + %ptr_46 = arith.muli %ptr, %cst_15 : tensor<128x1xi32> loc(#loc594) + %ptr_47 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc595) + %ptr_48 = tt.addptr %ptr_47, %ptr_46 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc595) + %ptr_49 = tt.broadcast %ptr_48 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc596) + %ptr_50 = tt.addptr %ptr_49, %ptr_44 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc596) + %do = tt.load %ptr_50 : tensor<128x128x!tt.ptr> loc(#loc597) + %Di = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc359) + %Di_51 = tt.addptr %Di, %offs_m2_38 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc359) + %Di_52 = tt.load %Di_51 : tensor<128x!tt.ptr> loc(#loc360) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc361) + %lse_53 = 
tt.addptr %lse, %offs_m2_38 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc361) + %lse_54 = tt.load %lse_53 : tensor<128x!tt.ptr> loc(#loc362) + %lse_55 = arith.cmpf oeq, %lse_54, %cst_19 : tensor<128xf32> loc(#loc363) + %lse_56 = arith.select %lse_55, %cst_18, %lse_54 : tensor<128xi1>, tensor<128xf32> loc(#loc364) + %lse_57 = tt.expand_dims %lse_56 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc365) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_29 : !tt.ptr, i32 loc(#loc366) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc367) + %kv_start_58 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc368) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_27 : !tt.ptr, i32 loc(#loc369) + %sparse_kv_num_blocks_59 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc370) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc371) + %offs_n2_60 = tt.splat %kv_start_58 : i32 -> tensor<64xi32> loc(#loc372) + %offs_n2_61 = arith.addi %offs_n2_60, %offs_n2 : tensor<64xi32> loc(#loc372) + %kT_ptrs = tt.expand_dims %offs_n2_61 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc598) + %kT_ptrs_62 = arith.muli %kT_ptrs, %cst_1 : tensor<1x64xi32> loc(#loc599) + %kT_ptrs_63 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc600) + %kT_ptrs_64 = tt.addptr %kT_ptrs_63, %kT_ptrs_62 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc600) + %kT_ptrs_65 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc601) + %kT_ptrs_66 = tt.broadcast %kT_ptrs_64 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc602) + %kT_ptrs_67 = tt.broadcast %kT_ptrs_65 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc602) + %kT_ptrs_68 = tt.addptr %kT_ptrs_66, %kT_ptrs_67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc602) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc603) + %vT_ptrs_69 = tt.addptr %vT_ptrs, %kT_ptrs_62 : tensor<1x64x!tt.ptr>, 
tensor<1x64xi32> loc(#loc603) + %vT_ptrs_70 = tt.broadcast %vT_ptrs_69 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc604) + %vT_ptrs_71 = tt.addptr %vT_ptrs_70, %kT_ptrs_67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc604) + %hi = arith.muli %sparse_kv_num_blocks_59, %c2_i32 : i32 loc(#loc605) + %hi_72 = arith.minsi %hi, %c32_i32 : i32 loc(#loc606) + %vT_ptrs_73:4 = scf.for %start_n = %c0_i32 to %hi_72 step %c1_i32 iter_args(%dq_95 = %cst_13, %offs_n2_96 = %offs_n2_61, %kT_ptrs_97 = %kT_ptrs_68, %vT_ptrs_98 = %vT_ptrs_71) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.load %kT_ptrs_97 : tensor<128x64x!tt.ptr> loc(#loc889) + %qk = tt.dot %q, %kT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc756) + %qk_99 = arith.mulf %qk, %cst_11 : tensor<128x64xf32> loc(#loc757) + %n = tt.expand_dims %offs_n2_96 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc758) + %tmp4 = tt.broadcast %ptr : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc759) + %tmp4_100 = tt.broadcast %n : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc759) + %tmp4_101 = arith.cmpi sge, %tmp4, %tmp4_100 : tensor<128x64xi32> loc(#loc759) + %tmp5 = arith.extsi %n : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc760) + %tmp7 = tt.addptr %in_ptr16, %off_zq : !tt.ptr, i32 loc(#loc761) + %tmp7_102 = tt.load %tmp7 : !tt.ptr loc(#loc762) + %tmp8 = tt.splat %tmp7_102 : i64 -> tensor<1x64xi64> loc(#loc763) + %tmp8_103 = arith.cmpi slt, %tmp5, %tmp8 : tensor<1x64xi64> loc(#loc763) + %tmp9 = arith.extsi %ptr : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc764) + %tmp10 = tt.splat %tmp7_102 : i64 -> tensor<128x1xi64> loc(#loc765) + %tmp10_104 = arith.cmpi slt, %tmp9, %tmp10 : tensor<128x1xi64> loc(#loc765) + %tmp11 = tt.broadcast %tmp8_103 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc766) + %tmp11_105 = tt.broadcast %tmp10_104 : tensor<128x1xi1> -> tensor<128x64xi1> 
loc(#loc766) + %tmp11_106 = arith.andi %tmp11, %tmp11_105 : tensor<128x64xi1> loc(#loc766) + %tmp12 = arith.andi %tmp4_101, %tmp11_106 : tensor<128x64xi1> loc(#loc767) + %tmp15 = arith.cmpi sge, %n, %cst_6 : tensor<1x64xi32> loc(#loc768) + %tmp16 = arith.remsi %n, %cst_6 : tensor<1x64xi32> loc(#loc769) + %tmp18 = arith.cmpi ne, %tmp16, %cst_10 : tensor<1x64xi32> loc(#loc770) + %tmp19 = arith.cmpi slt, %tmp16, %cst_10 : tensor<1x64xi32> loc(#loc771) + %tmp22 = arith.andi %tmp18, %tmp19 : tensor<1x64xi1> loc(#loc772) + %tmp23 = arith.addi %tmp16, %cst_6 : tensor<1x64xi32> loc(#loc773) + %tmp24 = arith.select %tmp22, %tmp23, %tmp16 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc774) + %tmp25 = arith.extsi %tmp24 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc775) + %tmp26 = arith.cmpi slt, %tmp25, %tmp8 : tensor<1x64xi64> loc(#loc776) + %tmp27 = arith.andi %tmp15, %tmp26 : tensor<1x64xi1> loc(#loc777) + %tmp28 = arith.subi %tmp4_100, %tmp4 : tensor<128x64xi32> loc(#loc778) + %tmp29 = arith.remsi %tmp28, %cst_5 : tensor<128x64xi32> loc(#loc779) + %tmp30 = arith.cmpi ne, %tmp29, %cst_9 : tensor<128x64xi32> loc(#loc780) + %tmp31 = arith.cmpi slt, %tmp29, %cst_9 : tensor<128x64xi32> loc(#loc781) + %tmp33 = arith.andi %tmp30, %tmp31 : tensor<128x64xi1> loc(#loc782) + %tmp34 = arith.addi %tmp29, %cst_5 : tensor<128x64xi32> loc(#loc783) + %tmp35 = arith.select %tmp33, %tmp34, %tmp29 : tensor<128x64xi1>, tensor<128x64xi32> loc(#loc784) + %tmp36 = arith.cmpi eq, %tmp35, %cst_9 : tensor<128x64xi32> loc(#loc785) + %tmp37 = tt.broadcast %tmp27 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc786) + %tmp37_107 = arith.andi %tmp37, %tmp36 : tensor<128x64xi1> loc(#loc786) + %tmp38 = arith.ori %tmp12, %tmp37_107 : tensor<128x64xi1> loc(#loc787) + %post_mod_scores = arith.select %tmp38, %qk_99, %cst_8 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc788) + %post_mod_scores_108 = arith.mulf %post_mod_scores, %cst_7 : tensor<128x64xf32> loc(#loc789) + %p = tt.broadcast %lse_57 : 
tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc790) + %p_109 = arith.subf %post_mod_scores_108, %p : tensor<128x64xf32> loc(#loc790) + %p_110 = math.exp2 %p_109 : tensor<128x64xf32> loc(#loc791) + %vT = tt.load %vT_ptrs_98 : tensor<128x64x!tt.ptr> loc(#loc890) + %dp = tt.dot %do, %vT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc793) + %ds = tt.expand_dims %Di_52 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc794) + %ds_111 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc795) + %ds_112 = arith.subf %dp, %ds_111 : tensor<128x64xf32> loc(#loc795) + %ds_113 = arith.mulf %p_110, %ds_112 : tensor<128x64xf32> loc(#loc796) + %ds_114 = arith.select %tmp38, %ds_113, %cst_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc797) + %ds_115 = arith.truncf %ds_114 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc798) + %dq_116 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc799) + %dq_117 = tt.dot %ds_115, %dq_116, %dq_95, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc800) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc801) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc802) + %cur_block_118 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc803) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc804) + %next_block_119 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_59 : i32 loc(#loc805) + %next_block_120 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc806) + %next_block_121 = tt.load %next_block_120, %next_block_119 evictionPolicy = evict_last : !tt.ptr loc(#loc807) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc808) + %needs_jump_122 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc809) + %needs_jump_123 = arith.cmpi eq, %needs_jump_122, %c0_i32 : i32 loc(#loc810) + %jump_to_block = arith.subi 
%next_block_121, %cur_block_118 : i32 loc(#loc811) + %jump_to_block_124 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc812) + %jump_to_block_125 = arith.subi %jump_to_block_124, %c64_i32 : i32 loc(#loc813) + %offset = arith.extui %needs_jump_123 : i1 to i32 loc(#loc814) + %offset_126 = arith.muli %jump_to_block_125, %offset : i32 loc(#loc814) + %offset_127 = arith.subi %c1_i32, %offset : i32 loc(#loc815) + %offset_128 = arith.muli %offset_127, %c64_i32 : i32 loc(#loc816) + %offset_129 = arith.addi %offset_126, %offset_128 : i32 loc(#loc817) + %kT_ptrs_130 = arith.muli %offset_129, %c128_i32 : i32 loc(#loc610) + %kT_ptrs_131 = tt.splat %kT_ptrs_130 : i32 -> tensor<128x64xi32> loc(#loc611) + %kT_ptrs_132 = tt.addptr %kT_ptrs_97, %kT_ptrs_131 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc611) + %vT_ptrs_133 = tt.addptr %vT_ptrs_98, %kT_ptrs_131 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc612) + %offs_n2_134 = tt.splat %offset_129 : i32 -> tensor<64xi32> loc(#loc613) + %offs_n2_135 = arith.addi %offs_n2_96, %offs_n2_134 : tensor<64xi32> loc(#loc613) + scf.yield %dq_117, %offs_n2_135, %kT_ptrs_132, %vT_ptrs_133 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc614) + } loc(#loc894) + %kv_indices_74 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_29 : !tt.ptr, i32 loc(#loc453) + %kv_start_75 = tt.load %kv_indices_74 : !tt.ptr loc(#loc454) + %kv_start_76 = arith.muli %kv_start_75, %c128_i32 : i32 loc(#loc455) + %sparse_kv_num_blocks_77 = tt.addptr %arg_FULL_KV_NUM_BLKS, %sparse_kv_num_blks_offset_27 : !tt.ptr, i32 loc(#loc456) + %sparse_kv_num_blocks_78 = tt.load %sparse_kv_num_blocks_77 : !tt.ptr loc(#loc457) + %offs_n2_79 = tt.splat %kv_start_76 : i32 -> tensor<64xi32> loc(#loc458) + %offs_n2_80 = arith.addi %offs_n2_79, %offs_n2 : tensor<64xi32> loc(#loc458) + %kT_ptrs_81 = tt.expand_dims %offs_n2_80 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc615) + %kT_ptrs_82 = arith.muli 
%kT_ptrs_81, %cst_1 : tensor<1x64xi32> loc(#loc616) + %kT_ptrs_83 = tt.addptr %kT_ptrs_63, %kT_ptrs_82 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc617) + %kT_ptrs_84 = tt.broadcast %kT_ptrs_83 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc618) + %kT_ptrs_85 = tt.addptr %kT_ptrs_84, %kT_ptrs_67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc618) + %vT_ptrs_86 = tt.addptr %vT_ptrs, %kT_ptrs_82 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc619) + %vT_ptrs_87 = tt.broadcast %vT_ptrs_86 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc620) + %vT_ptrs_88 = tt.addptr %vT_ptrs_87, %kT_ptrs_67 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc620) + %hi_89 = arith.muli %sparse_kv_num_blocks_78, %c2_i32 : i32 loc(#loc621) + %hi_90 = arith.minsi %hi_89, %c32_i32 : i32 loc(#loc622) + %vT_ptrs_91:4 = scf.for %start_n = %c0_i32 to %hi_90 step %c1_i32 iter_args(%dq_95 = %vT_ptrs_73#0, %offs_n2_96 = %offs_n2_80, %kT_ptrs_97 = %kT_ptrs_85, %vT_ptrs_98 = %vT_ptrs_88) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.load %kT_ptrs_97 : tensor<128x64x!tt.ptr> loc(#loc891) + %qk = tt.dot %q, %kT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc819) + %qk_99 = arith.mulf %qk, %cst_11 : tensor<128x64xf32> loc(#loc820) + %post_mod_scores = arith.mulf %qk_99, %cst_7 : tensor<128x64xf32> loc(#loc821) + %p = tt.broadcast %lse_57 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc822) + %p_100 = arith.subf %post_mod_scores, %p : tensor<128x64xf32> loc(#loc822) + %p_101 = math.exp2 %p_100 : tensor<128x64xf32> loc(#loc823) + %vT = tt.load %vT_ptrs_98 : tensor<128x64x!tt.ptr> loc(#loc892) + %dp = tt.dot %do, %vT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc825) + %ds = tt.expand_dims %Di_52 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc826) + %ds_102 = tt.broadcast %ds 
: tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc827) + %ds_103 = arith.subf %dp, %ds_102 : tensor<128x64xf32> loc(#loc827) + %ds_104 = arith.mulf %p_101, %ds_103 : tensor<128x64xf32> loc(#loc828) + %ds_105 = arith.truncf %ds_104 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc829) + %dq_106 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc830) + %dq_107 = tt.dot %ds_105, %dq_106, %dq_95, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc831) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc832) + %cur_block = tt.addptr %kv_indices_74, %cur_block_idx : !tt.ptr, i32 loc(#loc833) + %cur_block_108 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc834) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc835) + %next_block_109 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_78 : i32 loc(#loc836) + %next_block_110 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc837) + %next_block_111 = tt.load %next_block_110, %next_block_109 evictionPolicy = evict_last : !tt.ptr loc(#loc838) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc839) + %needs_jump_112 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc840) + %needs_jump_113 = arith.cmpi eq, %needs_jump_112, %c0_i32 : i32 loc(#loc841) + %jump_to_block = arith.subi %next_block_111, %cur_block_108 : i32 loc(#loc842) + %jump_to_block_114 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc843) + %jump_to_block_115 = arith.subi %jump_to_block_114, %c64_i32 : i32 loc(#loc844) + %offset = arith.extui %needs_jump_113 : i1 to i32 loc(#loc845) + %offset_116 = arith.muli %jump_to_block_115, %offset : i32 loc(#loc845) + %offset_117 = arith.subi %c1_i32, %offset : i32 loc(#loc846) + %offset_118 = arith.muli %offset_117, %c64_i32 : i32 loc(#loc847) + %offset_119 = arith.addi %offset_116, %offset_118 : i32 loc(#loc848) + %kT_ptrs_120 = arith.muli %offset_119, %c128_i32 : i32 loc(#loc625) + 
%kT_ptrs_121 = tt.splat %kT_ptrs_120 : i32 -> tensor<128x64xi32> loc(#loc626) + %kT_ptrs_122 = tt.addptr %kT_ptrs_97, %kT_ptrs_121 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc626) + %vT_ptrs_123 = tt.addptr %vT_ptrs_98, %kT_ptrs_121 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc627) + %offs_n2_124 = tt.splat %offset_119 : i32 -> tensor<64xi32> loc(#loc628) + %offs_n2_125 = arith.addi %offs_n2_96, %offs_n2_124 : tensor<64xi32> loc(#loc628) + scf.yield %dq_107, %offs_n2_125, %kT_ptrs_122, %vT_ptrs_123 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc629) + } loc(#loc895) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc460) + %dq_ptrs_92 = tt.addptr %dq_ptrs, %ptr_39 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc460) + %dq_ptrs_93 = tt.broadcast %dq_ptrs_92 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc461) + %dq_ptrs_94 = tt.addptr %dq_ptrs_93, %ptr_44 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc461) + %dq = arith.mulf %vT_ptrs_91#0, %cst_16 : tensor<128x128xf32> loc(#loc462) + %1 = arith.truncf %dq : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc160) + tt.store %dq_ptrs_94, %1 : tensor<128x128x!tt.ptr> loc(#loc160) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc463) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32> loc(#loc464) + %offs_n1_25 = arith.addi %offs_n1, %offs_k : tensor<128xi32> loc(#loc464) + %ptr = tt.expand_dims %offs_n1_25 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc630) + %ptr_26 = arith.muli %ptr, %cst_15 : tensor<128x1xi32> loc(#loc631) + %ptr_27 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc632) + %ptr_28 = tt.addptr %ptr_27, %ptr_26 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc632) + %ptr_29 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc633) + %ptr_30 = tt.broadcast %ptr_28 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> 
loc(#loc634) + %ptr_31 = tt.broadcast %ptr_29 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc634) + %ptr_32 = tt.addptr %ptr_30, %ptr_31 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc634) + %k = tt.load %ptr_32 : tensor<128x128x!tt.ptr> loc(#loc635) + %ptr_33 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc636) + %ptr_34 = tt.addptr %ptr_33, %ptr_26 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc636) + %ptr_35 = tt.broadcast %ptr_34 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc637) + %ptr_36 = tt.addptr %ptr_35, %ptr_31 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc637) + %v = tt.load %ptr_36 : tensor<128x128x!tt.ptr> loc(#loc638) + %dk:2 = scf.for %off_g = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%dv = %cst_13, %dk_46 = %cst_13) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc468) + %off_hq1_47 = arith.addi %off_hq1, %off_g : i32 loc(#loc469) + %q_adj1 = arith.muli %off_hq1_47, %c128_i32 : i32 loc(#loc470) + %q_adj1_48 = arith.muli %off_zq, %c8388608_i32 : i32 loc(#loc471) + %q_adj1_49 = arith.addi %q_adj1, %q_adj1_48 : i32 loc(#loc472) + %q_adj1_50 = arith.extsi %q_adj1_49 : i32 to i64 loc(#loc473) + %do_adj1 = arith.muli %off_hq1_47, %c262144_i32 : i32 loc(#loc474) + %do_adj1_51 = arith.addi %do_adj1, %q_adj1_48 : i32 loc(#loc475) + %do_adj1_52 = arith.extsi %do_adj1_51 : i32 to i64 loc(#loc476) + %off_chz1 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc477) + %off_chz1_53 = arith.addi %off_chz1, %off_hq1_47 : i32 loc(#loc478) + %off_chz1_54 = arith.muli %off_chz1_53, %c2048_i32 : i32 loc(#loc479) + %off_chz1_55 = arith.extsi %off_chz1_54 : i32 to i64 loc(#loc480) + %Q1 = tt.addptr %arg_Q, %q_adj1_50 : !tt.ptr, i64 loc(#loc481) + %DO1 = tt.addptr %arg_DO, %do_adj1_52 : !tt.ptr, i64 loc(#loc482) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_55 : !tt.ptr, i64 loc(#loc483) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_55 : !tt.ptr, i64 loc(#loc484) + 
%sparse_q_num_blks_offset = arith.muli %off_zkv, %c16_i32 : i32 loc(#loc485) + %sparse_q_num_blks_offset_56 = arith.addi %sparse_q_num_blks_offset, %pid : i32 loc(#loc486) + %sparse_q_idx_offset = arith.muli %off_zkv, %c256_i32 : i32 loc(#loc487) + %sparse_q_idx_offset_57 = arith.muli %pid, %c16_i32 : i32 loc(#loc488) + %sparse_q_idx_offset_58 = arith.addi %sparse_q_idx_offset, %sparse_q_idx_offset_57 : i32 loc(#loc489) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset_58 : !tt.ptr, i32 loc(#loc490) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc491) + %q_start_59 = arith.muli %q_start, %c128_i32 : i32 loc(#loc492) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %sparse_q_num_blks_offset_56 : !tt.ptr, i32 loc(#loc493) + %sparse_q_num_blocks_60 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc494) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc495) + %offs_m1_61 = tt.splat %q_start_59 : i32 -> tensor<64xi32> loc(#loc496) + %offs_m1_62 = arith.addi %offs_m1_61, %offs_m1 : tensor<64xi32> loc(#loc496) + %qT_ptrs = tt.expand_dims %offs_m1_62 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc640) + %qT_ptrs_63 = arith.muli %qT_ptrs, %cst_0 : tensor<1x64xi32> loc(#loc641) + %qT_ptrs_64 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc642) + %qT_ptrs_65 = tt.addptr %qT_ptrs_64, %qT_ptrs_63 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc642) + %qT_ptrs_66 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc643) + %qT_ptrs_67 = tt.broadcast %qT_ptrs_65 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc644) + %qT_ptrs_68 = tt.broadcast %qT_ptrs_66 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc644) + %qT_ptrs_69 = tt.addptr %qT_ptrs_67, %qT_ptrs_68 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc644) + %do_ptrs = tt.expand_dims %offs_m1_62 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc645) + %do_ptrs_70 = arith.muli %do_ptrs, %cst : 
tensor<64x1xi32> loc(#loc646) + %do_ptrs_71 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc647) + %do_ptrs_72 = tt.addptr %do_ptrs_71, %do_ptrs_70 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc647) + %do_ptrs_73 = tt.broadcast %do_ptrs_72 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc648) + %do_ptrs_74 = tt.broadcast %ptr_29 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc648) + %do_ptrs_75 = tt.addptr %do_ptrs_73, %do_ptrs_74 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc648) + %hi = arith.muli %sparse_q_num_blocks_60, %c2_i32 : i32 loc(#loc649) + %hi_76 = arith.minsi %hi, %c32_i32 : i32 loc(#loc650) + %do_ptrs_77:5 = scf.for %start_m = %c0_i32 to %hi_76 step %c1_i32 iter_args(%dk_98 = %dk_46, %dv_99 = %dv, %offs_m1_100 = %offs_m1_62, %qT_ptrs_101 = %qT_ptrs_69, %do_ptrs_102 = %do_ptrs_75) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.load %qT_ptrs_101 : tensor<128x64x!tt.ptr> loc(#loc850) + %lse = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc653) + %lse_103 = tt.addptr %lse, %offs_m1_100 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc653) + %lse_104 = tt.load %lse_103 : tensor<64x!tt.ptr> loc(#loc654) + %lse_105 = arith.cmpf oeq, %lse_104, %cst_4 : tensor<64xf32> loc(#loc655) + %lse_106 = arith.select %lse_105, %cst_3, %lse_104 : tensor<64xi1>, tensor<64xf32> loc(#loc656) + %qkT = tt.dot %k, %qT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc657) + %qkT_107 = arith.mulf %qkT, %cst_11 : tensor<128x64xf32> loc(#loc658) + %m = tt.expand_dims %offs_m1_100 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc659) + %tmp44 = tt.broadcast %m : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc660) + %tmp44_108 = tt.broadcast %ptr : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc660) + %tmp44_109 = arith.cmpi sge, %tmp44, %tmp44_108 : tensor<128x64xi32> loc(#loc660) + %tmp45 = 
arith.extsi %ptr : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc661) + %tmp47 = tt.addptr %in_ptr16, %off_zq : !tt.ptr, i32 loc(#loc662) + %tmp47_110 = tt.load %tmp47 : !tt.ptr loc(#loc663) + %tmp48 = tt.splat %tmp47_110 : i64 -> tensor<128x1xi64> loc(#loc664) + %tmp48_111 = arith.cmpi slt, %tmp45, %tmp48 : tensor<128x1xi64> loc(#loc664) + %tmp49 = arith.extsi %m : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc665) + %tmp50 = tt.splat %tmp47_110 : i64 -> tensor<1x64xi64> loc(#loc666) + %tmp50_112 = arith.cmpi slt, %tmp49, %tmp50 : tensor<1x64xi64> loc(#loc666) + %tmp51 = tt.broadcast %tmp48_111 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc667) + %tmp51_113 = tt.broadcast %tmp50_112 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc667) + %tmp51_114 = arith.andi %tmp51, %tmp51_113 : tensor<128x64xi1> loc(#loc667) + %tmp52 = arith.andi %tmp44_109, %tmp51_114 : tensor<128x64xi1> loc(#loc668) + %tmp55 = arith.cmpi sge, %ptr, %cst_14 : tensor<128x1xi32> loc(#loc669) + %tmp56 = arith.remsi %ptr, %cst_14 : tensor<128x1xi32> loc(#loc670) + %tmp58 = arith.cmpi ne, %tmp56, %cst_2 : tensor<128x1xi32> loc(#loc671) + %tmp59 = arith.cmpi slt, %tmp56, %cst_2 : tensor<128x1xi32> loc(#loc672) + %tmp62 = arith.andi %tmp58, %tmp59 : tensor<128x1xi1> loc(#loc673) + %tmp63 = arith.addi %tmp56, %cst_14 : tensor<128x1xi32> loc(#loc674) + %tmp64 = arith.select %tmp62, %tmp63, %tmp56 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc675) + %tmp65 = arith.extsi %tmp64 : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc676) + %tmp66 = arith.cmpi slt, %tmp65, %tmp48 : tensor<128x1xi64> loc(#loc677) + %tmp67 = arith.andi %tmp55, %tmp66 : tensor<128x1xi1> loc(#loc678) + %tmp68 = arith.subi %tmp44_108, %tmp44 : tensor<128x64xi32> loc(#loc679) + %tmp69 = arith.remsi %tmp68, %cst_5 : tensor<128x64xi32> loc(#loc680) + %tmp70 = arith.cmpi ne, %tmp69, %cst_9 : tensor<128x64xi32> loc(#loc681) + %tmp71 = arith.cmpi slt, %tmp69, %cst_9 : tensor<128x64xi32> loc(#loc682) + %tmp73 = arith.andi %tmp70, %tmp71 : 
tensor<128x64xi1> loc(#loc683) + %tmp74 = arith.addi %tmp69, %cst_5 : tensor<128x64xi32> loc(#loc684) + %tmp75 = arith.select %tmp73, %tmp74, %tmp69 : tensor<128x64xi1>, tensor<128x64xi32> loc(#loc685) + %tmp76 = arith.cmpi eq, %tmp75, %cst_9 : tensor<128x64xi32> loc(#loc686) + %tmp77 = tt.broadcast %tmp67 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc687) + %tmp77_115 = arith.andi %tmp77, %tmp76 : tensor<128x64xi1> loc(#loc687) + %tmp78 = arith.ori %tmp52, %tmp77_115 : tensor<128x64xi1> loc(#loc688) + %post_mod_scores = arith.select %tmp78, %qkT_107, %cst_8 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc689) + %post_mod_scores_116 = arith.mulf %post_mod_scores, %cst_7 : tensor<128x64xf32> loc(#loc690) + %pT = tt.expand_dims %lse_106 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc691) + %pT_117 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc692) + %pT_118 = arith.subf %post_mod_scores_116, %pT_117 : tensor<128x64xf32> loc(#loc692) + %pT_119 = math.exp2 %pT_118 : tensor<128x64xf32> loc(#loc693) + %do = tt.load %do_ptrs_102 : tensor<64x128x!tt.ptr> loc(#loc851) + %dv_120 = arith.truncf %pT_119 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc695) + %dv_121 = tt.dot %dv_120, %do, %dv_99, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc696) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc697) + %Di_122 = tt.addptr %Di, %offs_m1_100 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc697) + %Di_123 = tt.load %Di_122 : tensor<64x!tt.ptr> loc(#loc698) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc699) + %dpT_124 = tt.dot %v, %dpT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc700) + %dsT = tt.expand_dims %Di_123 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc701) + %dsT_125 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc702) + 
%dsT_126 = arith.subf %dpT_124, %dsT_125 : tensor<128x64xf32> loc(#loc702) + %dsT_127 = arith.mulf %pT_119, %dsT_126 : tensor<128x64xf32> loc(#loc703) + %dsT_128 = arith.select %tmp78, %dsT_127, %cst_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc704) + %dk_129 = arith.truncf %dsT_128 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc705) + %dk_130 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc706) + %dk_131 = tt.dot %dk_129, %dk_130, %dk_98, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc707) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc852) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc853) + %cur_block_132 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc854) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc855) + %next_block_133 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_60 : i32 loc(#loc856) + %next_block_134 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc857) + %next_block_135 = tt.load %next_block_134, %next_block_133 evictionPolicy = evict_last : !tt.ptr loc(#loc858) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc859) + %needs_jump_136 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc860) + %needs_jump_137 = arith.cmpi eq, %needs_jump_136, %c0_i32 : i32 loc(#loc861) + %jump_to_block = arith.subi %next_block_135, %cur_block_132 : i32 loc(#loc862) + %jump_to_block_138 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc863) + %jump_to_block_139 = arith.subi %jump_to_block_138, %c64_i32 : i32 loc(#loc864) + %offset = arith.extui %needs_jump_137 : i1 to i32 loc(#loc865) + %offset_140 = arith.muli %jump_to_block_139, %offset : i32 loc(#loc865) + %offset_141 = arith.subi %c1_i32, %offset : i32 loc(#loc866) + %offset_142 = arith.muli %offset_141, %c64_i32 : i32 loc(#loc867) + %offset_143 = arith.addi %offset_140, %offset_142 : i32 loc(#loc868) + %qT_ptrs_144 = 
arith.muli %offset_143, %c4096_i32 : i32 loc(#loc709) + %qT_ptrs_145 = tt.splat %qT_ptrs_144 : i32 -> tensor<128x64xi32> loc(#loc710) + %qT_ptrs_146 = tt.addptr %qT_ptrs_101, %qT_ptrs_145 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc710) + %do_ptrs_147 = arith.muli %offset_143, %c128_i32 : i32 loc(#loc711) + %do_ptrs_148 = tt.splat %do_ptrs_147 : i32 -> tensor<64x128xi32> loc(#loc712) + %do_ptrs_149 = tt.addptr %do_ptrs_102, %do_ptrs_148 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc712) + %offs_m1_150 = tt.splat %offset_143 : i32 -> tensor<64xi32> loc(#loc713) + %offs_m1_151 = arith.addi %offs_m1_100, %offs_m1_150 : tensor<64xi32> loc(#loc713) + scf.yield %dk_131, %dv_121, %offs_m1_151, %qT_ptrs_146, %do_ptrs_149 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc572) + } loc(#loc897) + %q_indices_78 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset_58 : !tt.ptr, i32 loc(#loc573) + %q_start_79 = tt.load %q_indices_78 : !tt.ptr loc(#loc574) + %q_start_80 = arith.muli %q_start_79, %c128_i32 : i32 loc(#loc575) + %sparse_q_num_blocks_81 = tt.addptr %arg_FULL_Q_NUM_BLKS, %sparse_q_num_blks_offset_56 : !tt.ptr, i32 loc(#loc576) + %sparse_q_num_blocks_82 = tt.load %sparse_q_num_blocks_81 : !tt.ptr loc(#loc577) + %offs_m1_83 = tt.splat %q_start_80 : i32 -> tensor<64xi32> loc(#loc578) + %offs_m1_84 = arith.addi %offs_m1_83, %offs_m1 : tensor<64xi32> loc(#loc578) + %qT_ptrs_85 = tt.expand_dims %offs_m1_84 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc714) + %qT_ptrs_86 = arith.muli %qT_ptrs_85, %cst_0 : tensor<1x64xi32> loc(#loc715) + %qT_ptrs_87 = tt.addptr %qT_ptrs_64, %qT_ptrs_86 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc716) + %qT_ptrs_88 = tt.broadcast %qT_ptrs_87 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc717) + %qT_ptrs_89 = tt.addptr %qT_ptrs_88, %qT_ptrs_68 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc717) + %do_ptrs_90 = tt.expand_dims 
%offs_m1_84 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc718) + %do_ptrs_91 = arith.muli %do_ptrs_90, %cst : tensor<64x1xi32> loc(#loc719) + %do_ptrs_92 = tt.addptr %do_ptrs_71, %do_ptrs_91 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc720) + %do_ptrs_93 = tt.broadcast %do_ptrs_92 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc721) + %do_ptrs_94 = tt.addptr %do_ptrs_93, %do_ptrs_74 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc721) + %hi_95 = arith.muli %sparse_q_num_blocks_82, %c2_i32 : i32 loc(#loc722) + %hi_96 = arith.minsi %hi_95, %c32_i32 : i32 loc(#loc723) + %do_ptrs_97:5 = scf.for %start_m = %c0_i32 to %hi_96 step %c1_i32 iter_args(%dk_98 = %do_ptrs_77#0, %dv_99 = %do_ptrs_77#1, %offs_m1_100 = %offs_m1_84, %qT_ptrs_101 = %qT_ptrs_89, %do_ptrs_102 = %do_ptrs_94) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.load %qT_ptrs_101 : tensor<128x64x!tt.ptr> loc(#loc869) + %lse = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc725) + %lse_103 = tt.addptr %lse, %offs_m1_100 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc725) + %lse_104 = tt.load %lse_103 : tensor<64x!tt.ptr> loc(#loc726) + %lse_105 = arith.cmpf oeq, %lse_104, %cst_4 : tensor<64xf32> loc(#loc727) + %lse_106 = arith.select %lse_105, %cst_3, %lse_104 : tensor<64xi1>, tensor<64xf32> loc(#loc728) + %qkT = tt.dot %k, %qT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc729) + %qkT_107 = arith.mulf %qkT, %cst_11 : tensor<128x64xf32> loc(#loc730) + %post_mod_scores = arith.mulf %qkT_107, %cst_7 : tensor<128x64xf32> loc(#loc731) + %pT = tt.expand_dims %lse_106 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc732) + %pT_108 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc733) + %pT_109 = arith.subf %post_mod_scores, %pT_108 : tensor<128x64xf32> loc(#loc733) + %pT_110 = math.exp2 %pT_109 : 
tensor<128x64xf32> loc(#loc734) + %do = tt.load %do_ptrs_102 : tensor<64x128x!tt.ptr> loc(#loc870) + %dv_111 = arith.truncf %pT_110 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc736) + %dv_112 = tt.dot %dv_111, %do, %dv_99, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc737) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc738) + %Di_113 = tt.addptr %Di, %offs_m1_100 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc738) + %Di_114 = tt.load %Di_113 : tensor<64x!tt.ptr> loc(#loc739) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc740) + %dpT_115 = tt.dot %v, %dpT, %cst_12, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc741) + %dsT = tt.expand_dims %Di_114 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc742) + %dsT_116 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc743) + %dsT_117 = arith.subf %dpT_115, %dsT_116 : tensor<128x64xf32> loc(#loc743) + %dsT_118 = arith.mulf %pT_110, %dsT_117 : tensor<128x64xf32> loc(#loc744) + %dk_119 = arith.truncf %dsT_118 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc745) + %dk_120 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc746) + %dk_121 = tt.dot %dk_119, %dk_120, %dk_98, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc747) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc871) + %cur_block = tt.addptr %q_indices_78, %cur_block_idx : !tt.ptr, i32 loc(#loc872) + %cur_block_122 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc873) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc874) + %next_block_123 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_82 : i32 loc(#loc875) + %next_block_124 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc876) + %next_block_125 = tt.load %next_block_124, 
%next_block_123 evictionPolicy = evict_last : !tt.ptr loc(#loc877) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc878) + %needs_jump_126 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc879) + %needs_jump_127 = arith.cmpi eq, %needs_jump_126, %c0_i32 : i32 loc(#loc880) + %jump_to_block = arith.subi %next_block_125, %cur_block_122 : i32 loc(#loc881) + %jump_to_block_128 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc882) + %jump_to_block_129 = arith.subi %jump_to_block_128, %c64_i32 : i32 loc(#loc883) + %offset = arith.extui %needs_jump_127 : i1 to i32 loc(#loc884) + %offset_130 = arith.muli %jump_to_block_129, %offset : i32 loc(#loc884) + %offset_131 = arith.subi %c1_i32, %offset : i32 loc(#loc885) + %offset_132 = arith.muli %offset_131, %c64_i32 : i32 loc(#loc886) + %offset_133 = arith.addi %offset_130, %offset_132 : i32 loc(#loc887) + %qT_ptrs_134 = arith.muli %offset_133, %c4096_i32 : i32 loc(#loc749) + %qT_ptrs_135 = tt.splat %qT_ptrs_134 : i32 -> tensor<128x64xi32> loc(#loc750) + %qT_ptrs_136 = tt.addptr %qT_ptrs_101, %qT_ptrs_135 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc750) + %do_ptrs_137 = arith.muli %offset_133, %c128_i32 : i32 loc(#loc751) + %do_ptrs_138 = tt.splat %do_ptrs_137 : i32 -> tensor<64x128xi32> loc(#loc752) + %do_ptrs_139 = tt.addptr %do_ptrs_102, %do_ptrs_138 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc752) + %offs_m1_140 = tt.splat %offset_133 : i32 -> tensor<64xi32> loc(#loc753) + %offs_m1_141 = arith.addi %offs_m1_100, %offs_m1_140 : tensor<64xi32> loc(#loc753) + scf.yield %dk_121, %dv_112, %offs_m1_141, %qT_ptrs_136, %do_ptrs_139 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc580) + } loc(#loc898) + scf.yield %do_ptrs_97#1, %do_ptrs_97#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc279) + } loc(#loc639) + %dv_ptrs = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc581) + %dv_ptrs_37 = tt.addptr %dv_ptrs, %ptr_26 : 
tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc581) + %dv_ptrs_38 = tt.broadcast %dv_ptrs_37 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc582) + %dv_ptrs_39 = tt.addptr %dv_ptrs_38, %ptr_31 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc582) + %1 = arith.truncf %dk#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc282) + tt.store %dv_ptrs_39, %1 : tensor<128x128x!tt.ptr> loc(#loc282) + %dk_40 = arith.mulf %dk#1, %cst_16 : tensor<128x128xf32> loc(#loc583) + %mask = arith.cmpi slt, %ptr, %cst_14 : tensor<128x1xi32> loc(#loc584) + %xindex = tt.broadcast %ptr_26 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc585) + %xindex_41 = arith.addi %ptr_31, %xindex : tensor<128x128xi32> loc(#loc585) + %xindex_42 = tt.splat %k_adj : i32 -> tensor<128x128xi32> loc(#loc586) + %xindex_43 = arith.addi %xindex_41, %xindex_42 : tensor<128x128xi32> loc(#loc586) + %xindex_44 = tt.splat %dv_adj : i32 -> tensor<128x128xi32> loc(#loc587) + %xindex_45 = arith.addi %xindex_43, %xindex_44 : tensor<128x128xi32> loc(#loc587) + %2 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc288) + %3 = tt.addptr %2, %xindex_45 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc288) + %4 = tt.broadcast %mask : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc289) + %5 = arith.truncf %dk_40 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc289) + tt.store %3, %5, %4 : tensor<128x128x!tt.ptr> loc(#loc289) + } loc(#loc18) + tt.return loc(#loc290) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":111:24) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":115:27) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":116:28) +#loc5 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":117:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:25) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:47) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:35) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":124:59) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:50) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":128:61) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":131:9) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":132:9) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":133:10) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":136:26) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:14) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:7) +#loc19 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":140:24) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:29) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:54) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":144:44) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":145:35) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:55) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":154:78) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:50) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:83) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":155:68) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:30) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:52) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:40) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":158:63) +#loc33 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:32) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:42) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":159:66) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:30) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:35) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:46) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":161:56) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":163:17) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":164:19) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":167:19) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":168:21) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":169:25) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":174:36) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":175:29) +#loc47 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:27) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":178:107) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:38) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:20) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:56) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":825:49) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":835:23) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":179:111) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:34) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":185:25) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:33) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":186:26) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:30) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":190:50) +#loc61 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":191:18) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":195:30) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:27) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":196:41) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:53) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":197:39) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:42) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":199:29) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:26) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":207:12) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:37) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:18) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:56) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":390:49) +#loc75 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:18) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":391:49) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:43) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":395:63) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":397:28) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":458:105) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":405:12) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":459:19) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":461:14) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":464:36) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":482:23) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":483:23) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:34) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":485:23) +#loc89 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":486:22) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":487:23) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":488:23) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":489:23) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":490:23) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":493:24) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":494:24) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":496:25) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":497:92) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":500:24) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":501:24) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":502:39) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":503:25) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":504:24) +#loc103 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":505:24) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":506:23) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":507:25) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":508:25) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":509:92) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":511:24) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":512:24) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":513:39) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":514:25) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":515:24) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":516:24) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":521:69) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":524:27) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:39) +#loc117 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":525:21) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":528:104) +#loc119 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":530:20) +#loc120 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:22) +#loc121 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:19) +#loc122 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":531:14) +#loc123 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":549:43) +#loc124 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":551:15) +#loc125 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:30) +#loc126 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":553:21) +#loc127 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":788:33) +#loc128 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":411:64) +#loc129 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:38) +#loc130 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":789:24) +#loc131 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:109) +#loc132 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:113) +#loc133 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:55) +#loc134 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":790:25) +#loc135 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:30) +#loc136 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:35) +#loc137 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":791:60) +#loc138 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:34) +#loc139 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:48) +#loc140 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":792:63) +#loc141 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:29) +#loc142 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:47) +#loc143 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:61) +#loc144 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":793:42) +#loc145 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:28) +#loc146 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":414:19) +#loc147 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":415:19) +#loc148 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":417:19) +#loc149 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":417:8) +#loc150 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":214:39) +#loc151 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:31) +#loc152 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":215:45) +#loc153 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:62) +#loc154 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":216:43) +#loc155 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":218:33) +#loc156 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":226:16) +#loc157 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:24) +#loc158 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":231:56) +#loc159 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":232:14) +#loc160 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":234:30) +#loc161 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":252:25) +#loc162 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":253:29) +#loc163 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":256:107) +#loc164 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":257:107) +#loc165 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":262:30) +#loc166 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:32) +#loc167 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":263:51) +#loc168 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:34) +#loc169 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:56) +#loc170 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:44) +#loc171 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":266:67) +#loc172 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:36) +#loc173 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:46) +#loc174 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":267:70) +#loc175 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:34) +#loc176 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:39) +#loc177 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:50) +#loc178 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":269:60) +#loc179 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":271:21) +#loc180 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":272:23) +#loc181 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":275:25) +#loc182 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":276:29) +#loc183 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:58) +#loc184 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":281:80) +#loc185 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:53) +#loc186 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:81) +#loc187 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":282:70) +#loc188 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":286:32) +#loc189 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:30) +#loc190 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":287:43) +#loc191 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:55) +#loc192 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":288:42) +#loc193 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:45) +#loc194 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":290:32) +#loc195 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:26) +#loc196 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":298:16) +#loc197 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:37) +#loc198 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:18) +#loc199 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:56) +#loc200 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":601:49) +#loc201 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:27) +#loc202 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:38) +#loc203 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:19) +#loc204 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":602:51) +#loc205 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:42) +#loc206 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":608:61) +#loc207 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":610:28) +#loc208 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":669:105) +#loc209 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":618:12) +#loc210 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:28) +#loc211 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":672:22) +#loc212 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:26) +#loc213 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":675:46) +#loc214 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":676:20) +#loc215 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":678:15) +#loc216 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":680:36) +#loc217 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":698:25) +#loc218 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":699:25) +#loc219 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:35) +#loc220 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":701:24) +#loc221 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":702:24) +#loc222 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":703:25) +#loc223 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":704:24) +#loc224 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":705:24) +#loc225 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":706:24) +#loc226 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":709:25) +#loc227 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":710:25) +#loc228 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":712:25) +#loc229 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":713:92) +#loc230 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":716:24) +#loc231 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":717:24) +#loc232 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":718:39) +#loc233 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":719:25) +#loc234 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":720:24) +#loc235 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":721:24) +#loc236 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":722:24) +#loc237 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":723:25) +#loc238 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":724:25) +#loc239 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":725:92) +#loc240 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":727:24) +#loc241 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":728:24) +#loc242 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":729:39) +#loc243 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":730:25) +#loc244 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":731:24) +#loc245 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":732:24) +#loc246 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":736:69) +#loc247 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":739:27) +#loc248 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:44) +#loc249 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:40) +#loc250 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":740:22) +#loc251 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":741:99) +#loc252 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:24) +#loc253 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":744:43) +#loc254 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:29) +#loc255 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":746:21) +#loc256 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:29) +#loc257 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":750:20) +#loc258 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:25) +#loc259 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:22) +#loc260 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":751:16) +#loc261 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":773:45) +#loc262 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:24) +#loc263 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:52) +#loc264 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":775:43) +#loc265 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":623:62) +#loc266 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:28) +#loc267 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":626:19) +#loc268 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:28) +#loc269 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":627:19) +#loc270 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":628:19) +#loc271 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":628:8) +#loc272 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":306:41) +#loc273 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:34) +#loc274 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":307:47) +#loc275 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:64) +#loc276 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":308:46) +#loc277 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":310:36) +#loc278 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":318:20) +#loc279 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":303:12) +#loc280 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:23) +#loc281 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":323:55) +#loc282 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":330:30) +#loc283 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":334:14) +#loc284 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":337:29) +#loc285 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:27) +#loc286 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:41) +#loc287 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":344:58) +#loc288 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:29) +#loc289 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":345:69) +#loc290 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bm/cbmx7oaxptcsxwpex6gfoa6ejqiyihkp7hpmjiunqdckofo5tqvz.py":139:4) +#loc309 = loc("pid"(#loc2)) +#loc310 = loc("off_zq"(#loc3)) +#loc311 = loc("off_hkv"(#loc4)) +#loc312 = loc("off_zkv"(#loc5)) +#loc313 = loc("k_adj"(#loc6)) +#loc314 = loc("k_adj"(#loc7)) +#loc315 = loc("k_adj"(#loc8)) +#loc316 = loc("k_adj"(#loc9)) +#loc317 = loc("dv_adj"(#loc10)) +#loc318 = loc("dv_adj"(#loc11)) +#loc319 = loc("dv_adj"(#loc12)) +#loc320 = loc("K"(#loc13)) +#loc321 = loc("V"(#loc14)) +#loc322 = loc("DV"(#loc15)) +#loc323 = loc("offs_k"(#loc16)) +#loc324 = loc("off_pid"(#loc19)) +#loc325 = loc("off_hq2"(#loc20)) +#loc326 = loc("off_hq2"(#loc21)) +#loc327 = loc("off_hq2"(#loc22)) +#loc328 = loc("start_m2_block"(#loc23)) +#loc329 = loc("sparse_kv_num_blks_offset"(#loc24)) +#loc330 = loc("sparse_kv_num_blks_offset"(#loc25)) +#loc331 = loc("sparse_kv_idx_offset"(#loc26)) +#loc332 = loc("sparse_kv_idx_offset"(#loc27)) +#loc333 = loc("sparse_kv_idx_offset"(#loc28)) +#loc334 = loc("q_adj2"(#loc29)) +#loc335 = loc("q_adj2"(#loc30)) +#loc336 = loc("q_adj2"(#loc31)) +#loc337 = loc("q_adj2"(#loc32)) +#loc338 = loc("do_adj2"(#loc33)) +#loc339 = loc("do_adj2"(#loc34)) +#loc340 = loc("do_adj2"(#loc35)) +#loc341 = loc("off_chz2"(#loc36)) +#loc342 = 
loc("off_chz2"(#loc37)) +#loc343 = loc("off_chz2"(#loc38)) +#loc344 = loc("off_chz2"(#loc39)) +#loc345 = loc("Q2"(#loc40)) +#loc346 = loc("DO2"(#loc41)) +#loc347 = loc("DQ2"(#loc42)) +#loc348 = loc("LSE2"(#loc43)) +#loc349 = loc("DELTA2"(#loc44)) +#loc350 = loc("start_m2"(#loc45)) +#loc351 = loc("offs_m2"(#loc46)) +#loc352 = loc("ptr"(#loc47)) +#loc353 = loc("q"(#loc48)) +#loc354 = loc("ptr"(#loc49)) +#loc355 = loc("ptr"(#loc50)) +#loc356 = loc("ptr"(#loc51)) +#loc357 = loc("ptr"(#loc52)) +#loc358 = loc("do"(#loc54)) +#loc359 = loc("Di"(#loc55)) +#loc360 = loc("Di"(#loc56)) +#loc361 = loc("lse"(#loc57)) +#loc362 = loc("lse"(#loc58)) +#loc363 = loc("lse"(#loc59)) +#loc364 = loc("lse"(#loc60)) +#loc365 = loc("lse"(#loc61)) +#loc366 = loc("kv_indices"(#loc62)) +#loc367 = loc("kv_start"(#loc63)) +#loc368 = loc("kv_start"(#loc64)) +#loc369 = loc("sparse_kv_num_blocks"(#loc65)) +#loc370 = loc("sparse_kv_num_blocks"(#loc66)) +#loc371 = loc("offs_n2"(#loc67)) +#loc372 = loc("offs_n2"(#loc68)) +#loc373 = loc("kT_ptrs"(#loc69)) +#loc374 = loc("dq"(#loc70)) +#loc375 = loc("kT_ptrs"(#loc71)) +#loc376 = loc("kT_ptrs"(#loc72)) +#loc377 = loc("kT_ptrs"(#loc73)) +#loc378 = loc("kT_ptrs"(#loc74)) +#loc379 = loc("vT_ptrs"(#loc75)) +#loc380 = loc("vT_ptrs"(#loc76)) +#loc381 = loc("hi"(#loc77)) +#loc382 = loc("hi"(#loc78)) +#loc383 = loc("dq"(#loc79)) +#loc384 = loc("kT"(#loc80)) +#loc385 = loc("dq"(#loc81)) +#loc386 = loc("qk"(#loc82)) +#loc387 = loc("qk"(#loc83)) +#loc388 = loc("n"(#loc84)) +#loc389 = loc("tmp4"(#loc85)) +#loc390 = loc("tmp5"(#loc86)) +#loc391 = loc("tmp7"(#loc87)) +#loc392 = loc("tmp7"(#loc88)) +#loc393 = loc("tmp8"(#loc89)) +#loc394 = loc("tmp9"(#loc90)) +#loc395 = loc("tmp10"(#loc91)) +#loc396 = loc("tmp11"(#loc92)) +#loc397 = loc("tmp12"(#loc93)) +#loc398 = loc("tmp15"(#loc94)) +#loc399 = loc("tmp16"(#loc95)) +#loc400 = loc("tmp18"(#loc96)) +#loc401 = loc("tmp19"(#loc97)) +#loc402 = loc("tmp22"(#loc98)) +#loc403 = loc("tmp23"(#loc99)) +#loc404 = 
loc("tmp24"(#loc100)) +#loc405 = loc("tmp25"(#loc101)) +#loc406 = loc("tmp26"(#loc102)) +#loc407 = loc("tmp27"(#loc103)) +#loc408 = loc("tmp28"(#loc104)) +#loc409 = loc("tmp29"(#loc105)) +#loc410 = loc("tmp30"(#loc106)) +#loc411 = loc("tmp31"(#loc107)) +#loc412 = loc("tmp33"(#loc108)) +#loc413 = loc("tmp34"(#loc109)) +#loc414 = loc("tmp35"(#loc110)) +#loc415 = loc("tmp36"(#loc111)) +#loc416 = loc("tmp37"(#loc112)) +#loc417 = loc("tmp38"(#loc113)) +#loc418 = loc("post_mod_scores"(#loc114)) +#loc419 = loc("post_mod_scores"(#loc115)) +#loc420 = loc("p"(#loc116)) +#loc421 = loc("p"(#loc117)) +#loc422 = loc("vT"(#loc118)) +#loc423 = loc("dp"(#loc119)) +#loc424 = loc("ds"(#loc120)) +#loc425 = loc("ds"(#loc121)) +#loc426 = loc("ds"(#loc122)) +#loc427 = loc("ds"(#loc123)) +#loc428 = loc("ds"(#loc124)) +#loc429 = loc("dq"(#loc125)) +#loc430 = loc("dq"(#loc126)) +#loc431 = loc("cur_block_idx"(#loc127)) +#loc432 = loc("offset"(#loc128)) +#loc433 = loc("cur_block"(#loc129)) +#loc434 = loc("cur_block"(#loc130)) +#loc435 = loc("next_block"(#loc131)) +#loc436 = loc("next_block"(#loc132)) +#loc437 = loc("next_block"(#loc133)) +#loc438 = loc("next_block"(#loc134)) +#loc439 = loc("needs_jump"(#loc135)) +#loc440 = loc("needs_jump"(#loc136)) +#loc441 = loc("needs_jump"(#loc137)) +#loc442 = loc("jump_to_block"(#loc138)) +#loc443 = loc("jump_to_block"(#loc139)) +#loc444 = loc("jump_to_block"(#loc140)) +#loc445 = loc("offset"(#loc141)) +#loc446 = loc("offset"(#loc142)) +#loc447 = loc("offset"(#loc143)) +#loc448 = loc("offset"(#loc144)) +#loc449 = loc("kT_ptrs"(#loc145)) +#loc450 = loc("kT_ptrs"(#loc146)) +#loc451 = loc("vT_ptrs"(#loc147)) +#loc452 = loc("offs_n2"(#loc148)) +#loc453 = loc("kv_indices"(#loc150)) +#loc454 = loc("kv_start"(#loc151)) +#loc455 = loc("kv_start"(#loc152)) +#loc456 = loc("sparse_kv_num_blocks"(#loc153)) +#loc457 = loc("sparse_kv_num_blocks"(#loc154)) +#loc458 = loc("offs_n2"(#loc155)) +#loc459 = loc("dq"(#loc156)) +#loc460 = loc("dq_ptrs"(#loc157)) +#loc461 = 
loc("dq_ptrs"(#loc158)) +#loc462 = loc("dq"(#loc159)) +#loc463 = loc("start_n1"(#loc161)) +#loc464 = loc("offs_n1"(#loc162)) +#loc465 = loc("k"(#loc163)) +#loc466 = loc("v"(#loc164)) +#loc467 = loc("dv"(#loc165)) +#loc468 = loc("off_hq1"(#loc166)) +#loc469 = loc("off_hq1"(#loc167)) +#loc470 = loc("q_adj1"(#loc168)) +#loc471 = loc("q_adj1"(#loc169)) +#loc472 = loc("q_adj1"(#loc170)) +#loc473 = loc("q_adj1"(#loc171)) +#loc474 = loc("do_adj1"(#loc172)) +#loc475 = loc("do_adj1"(#loc173)) +#loc476 = loc("do_adj1"(#loc174)) +#loc477 = loc("off_chz1"(#loc175)) +#loc478 = loc("off_chz1"(#loc176)) +#loc479 = loc("off_chz1"(#loc177)) +#loc480 = loc("off_chz1"(#loc178)) +#loc481 = loc("Q1"(#loc179)) +#loc482 = loc("DO1"(#loc180)) +#loc483 = loc("LSE1"(#loc181)) +#loc484 = loc("DELTA1"(#loc182)) +#loc485 = loc("sparse_q_num_blks_offset"(#loc183)) +#loc486 = loc("sparse_q_num_blks_offset"(#loc184)) +#loc487 = loc("sparse_q_idx_offset"(#loc185)) +#loc488 = loc("sparse_q_idx_offset"(#loc186)) +#loc489 = loc("sparse_q_idx_offset"(#loc187)) +#loc490 = loc("q_indices"(#loc188)) +#loc491 = loc("q_start"(#loc189)) +#loc492 = loc("q_start"(#loc190)) +#loc493 = loc("sparse_q_num_blocks"(#loc191)) +#loc494 = loc("sparse_q_num_blocks"(#loc192)) +#loc495 = loc("offs_m1"(#loc193)) +#loc496 = loc("offs_m1"(#loc194)) +#loc497 = loc("qT_ptrs"(#loc195)) +#loc498 = loc("qT_ptrs"(#loc197)) +#loc499 = loc("qT_ptrs"(#loc198)) +#loc500 = loc("qT_ptrs"(#loc199)) +#loc501 = loc("qT_ptrs"(#loc200)) +#loc502 = loc("do_ptrs"(#loc201)) +#loc503 = loc("do_ptrs"(#loc202)) +#loc504 = loc("do_ptrs"(#loc203)) +#loc505 = loc("do_ptrs"(#loc204)) +#loc506 = loc("hi"(#loc205)) +#loc507 = loc("hi"(#loc206)) +#loc508 = loc("dk"(#loc207)) +#loc509 = loc("qT"(#loc208)) +#loc510 = loc(callsite(#loc209 at #loc196)) +#loc511 = loc("lse"(#loc210)) +#loc512 = loc("lse"(#loc211)) +#loc513 = loc("lse"(#loc212)) +#loc514 = loc("lse"(#loc213)) +#loc515 = loc("qkT"(#loc214)) +#loc516 = loc("qkT"(#loc215)) +#loc517 = 
loc("m"(#loc216)) +#loc518 = loc("tmp44"(#loc217)) +#loc519 = loc("tmp45"(#loc218)) +#loc520 = loc("tmp47"(#loc219)) +#loc521 = loc("tmp47"(#loc220)) +#loc522 = loc("tmp48"(#loc221)) +#loc523 = loc("tmp49"(#loc222)) +#loc524 = loc("tmp50"(#loc223)) +#loc525 = loc("tmp51"(#loc224)) +#loc526 = loc("tmp52"(#loc225)) +#loc527 = loc("tmp55"(#loc226)) +#loc528 = loc("tmp56"(#loc227)) +#loc529 = loc("tmp58"(#loc228)) +#loc530 = loc("tmp59"(#loc229)) +#loc531 = loc("tmp62"(#loc230)) +#loc532 = loc("tmp63"(#loc231)) +#loc533 = loc("tmp64"(#loc232)) +#loc534 = loc("tmp65"(#loc233)) +#loc535 = loc("tmp66"(#loc234)) +#loc536 = loc("tmp67"(#loc235)) +#loc537 = loc("tmp68"(#loc236)) +#loc538 = loc("tmp69"(#loc237)) +#loc539 = loc("tmp70"(#loc238)) +#loc540 = loc("tmp71"(#loc239)) +#loc541 = loc("tmp73"(#loc240)) +#loc542 = loc("tmp74"(#loc241)) +#loc543 = loc("tmp75"(#loc242)) +#loc544 = loc("tmp76"(#loc243)) +#loc545 = loc("tmp77"(#loc244)) +#loc546 = loc("tmp78"(#loc245)) +#loc547 = loc("post_mod_scores"(#loc246)) +#loc548 = loc("post_mod_scores"(#loc247)) +#loc549 = loc("pT"(#loc248)) +#loc550 = loc("pT"(#loc249)) +#loc551 = loc("pT"(#loc250)) +#loc552 = loc("do"(#loc251)) +#loc553 = loc("dv"(#loc252)) +#loc554 = loc("dv"(#loc253)) +#loc555 = loc("Di"(#loc254)) +#loc556 = loc("Di"(#loc255)) +#loc557 = loc("dpT"(#loc256)) +#loc558 = loc("dpT"(#loc257)) +#loc559 = loc("dsT"(#loc258)) +#loc560 = loc("dsT"(#loc259)) +#loc561 = loc("dsT"(#loc260)) +#loc562 = loc("dsT"(#loc261)) +#loc563 = loc("dk"(#loc262)) +#loc564 = loc("dk"(#loc263)) +#loc565 = loc("dk"(#loc264)) +#loc566 = loc("offset"(#loc265)) +#loc567 = loc("qT_ptrs"(#loc266)) +#loc568 = loc("qT_ptrs"(#loc267)) +#loc569 = loc("do_ptrs"(#loc268)) +#loc570 = loc("do_ptrs"(#loc269)) +#loc571 = loc("offs_m1"(#loc270)) +#loc572 = loc(callsite(#loc271 at #loc196)) +#loc573 = loc("q_indices"(#loc272)) +#loc574 = loc("q_start"(#loc273)) +#loc575 = loc("q_start"(#loc274)) +#loc576 = loc("sparse_q_num_blocks"(#loc275)) +#loc577 = 
loc("sparse_q_num_blocks"(#loc276)) +#loc578 = loc("offs_m1"(#loc277)) +#loc579 = loc(callsite(#loc209 at #loc278)) +#loc580 = loc(callsite(#loc271 at #loc278)) +#loc581 = loc("dv_ptrs"(#loc280)) +#loc582 = loc("dv_ptrs"(#loc281)) +#loc583 = loc("dk"(#loc283)) +#loc584 = loc("mask"(#loc284)) +#loc585 = loc("xindex"(#loc285)) +#loc586 = loc("xindex"(#loc286)) +#loc587 = loc("xindex"(#loc287)) +#loc588 = loc(callsite(#loc352 at #loc353)) +#loc589 = loc(callsite(#loc354 at #loc353)) +#loc590 = loc(callsite(#loc355 at #loc353)) +#loc591 = loc(callsite(#loc356 at #loc353)) +#loc592 = loc(callsite(#loc357 at #loc353)) +#loc593 = loc(callsite(#loc53 at #loc353)) +#loc594 = loc(callsite(#loc354 at #loc358)) +#loc595 = loc(callsite(#loc355 at #loc358)) +#loc596 = loc(callsite(#loc357 at #loc358)) +#loc597 = loc(callsite(#loc53 at #loc358)) +#loc598 = loc(callsite(#loc373 at #loc374)) +#loc599 = loc(callsite(#loc375 at #loc374)) +#loc600 = loc(callsite(#loc376 at #loc374)) +#loc601 = loc(callsite(#loc377 at #loc374)) +#loc602 = loc(callsite(#loc378 at #loc374)) +#loc603 = loc(callsite(#loc379 at #loc374)) +#loc604 = loc(callsite(#loc380 at #loc374)) +#loc605 = loc(callsite(#loc381 at #loc374)) +#loc606 = loc(callsite(#loc382 at #loc374)) +#loc607 = loc("offs_n2"(#loc383)) +#loc608 = loc(callsite(#loc385 at #loc374)) +#loc609 = loc(callsite(#loc432 at #loc374)) +#loc610 = loc(callsite(#loc449 at #loc374)) +#loc611 = loc(callsite(#loc450 at #loc374)) +#loc612 = loc(callsite(#loc451 at #loc374)) +#loc613 = loc(callsite(#loc452 at #loc374)) +#loc614 = loc(callsite(#loc149 at #loc374)) +#loc615 = loc(callsite(#loc373 at #loc459)) +#loc616 = loc(callsite(#loc375 at #loc459)) +#loc617 = loc(callsite(#loc376 at #loc459)) +#loc618 = loc(callsite(#loc378 at #loc459)) +#loc619 = loc(callsite(#loc379 at #loc459)) +#loc620 = loc(callsite(#loc380 at #loc459)) +#loc621 = loc(callsite(#loc381 at #loc459)) +#loc622 = loc(callsite(#loc382 at #loc459)) +#loc623 = loc(callsite(#loc385 at 
#loc459)) +#loc624 = loc(callsite(#loc432 at #loc459)) +#loc625 = loc(callsite(#loc449 at #loc459)) +#loc626 = loc(callsite(#loc450 at #loc459)) +#loc627 = loc(callsite(#loc451 at #loc459)) +#loc628 = loc(callsite(#loc452 at #loc459)) +#loc629 = loc(callsite(#loc149 at #loc459)) +#loc630 = loc(callsite(#loc352 at #loc465)) +#loc631 = loc(callsite(#loc354 at #loc465)) +#loc632 = loc(callsite(#loc355 at #loc465)) +#loc633 = loc(callsite(#loc356 at #loc465)) +#loc634 = loc(callsite(#loc357 at #loc465)) +#loc635 = loc(callsite(#loc53 at #loc465)) +#loc636 = loc(callsite(#loc355 at #loc466)) +#loc637 = loc(callsite(#loc357 at #loc466)) +#loc638 = loc(callsite(#loc53 at #loc466)) +#loc639 = loc("dk"(#loc467)) +#loc640 = loc(callsite(#loc497 at #loc196)) +#loc641 = loc(callsite(#loc498 at #loc196)) +#loc642 = loc(callsite(#loc499 at #loc196)) +#loc643 = loc(callsite(#loc500 at #loc196)) +#loc644 = loc(callsite(#loc501 at #loc196)) +#loc645 = loc(callsite(#loc502 at #loc196)) +#loc646 = loc(callsite(#loc503 at #loc196)) +#loc647 = loc(callsite(#loc504 at #loc196)) +#loc648 = loc(callsite(#loc505 at #loc196)) +#loc649 = loc(callsite(#loc506 at #loc196)) +#loc650 = loc(callsite(#loc507 at #loc196)) +#loc651 = loc("dv"(#loc508)) +#loc652 = loc(callsite(#loc509 at #loc510)) +#loc653 = loc(callsite(#loc511 at #loc510)) +#loc654 = loc(callsite(#loc512 at #loc510)) +#loc655 = loc(callsite(#loc513 at #loc510)) +#loc656 = loc(callsite(#loc514 at #loc510)) +#loc657 = loc(callsite(#loc515 at #loc510)) +#loc658 = loc(callsite(#loc516 at #loc510)) +#loc659 = loc(callsite(#loc517 at #loc510)) +#loc660 = loc(callsite(#loc518 at #loc510)) +#loc661 = loc(callsite(#loc519 at #loc510)) +#loc662 = loc(callsite(#loc520 at #loc510)) +#loc663 = loc(callsite(#loc521 at #loc510)) +#loc664 = loc(callsite(#loc522 at #loc510)) +#loc665 = loc(callsite(#loc523 at #loc510)) +#loc666 = loc(callsite(#loc524 at #loc510)) +#loc667 = loc(callsite(#loc525 at #loc510)) +#loc668 = loc(callsite(#loc526 at 
#loc510)) +#loc669 = loc(callsite(#loc527 at #loc510)) +#loc670 = loc(callsite(#loc528 at #loc510)) +#loc671 = loc(callsite(#loc529 at #loc510)) +#loc672 = loc(callsite(#loc530 at #loc510)) +#loc673 = loc(callsite(#loc531 at #loc510)) +#loc674 = loc(callsite(#loc532 at #loc510)) +#loc675 = loc(callsite(#loc533 at #loc510)) +#loc676 = loc(callsite(#loc534 at #loc510)) +#loc677 = loc(callsite(#loc535 at #loc510)) +#loc678 = loc(callsite(#loc536 at #loc510)) +#loc679 = loc(callsite(#loc537 at #loc510)) +#loc680 = loc(callsite(#loc538 at #loc510)) +#loc681 = loc(callsite(#loc539 at #loc510)) +#loc682 = loc(callsite(#loc540 at #loc510)) +#loc683 = loc(callsite(#loc541 at #loc510)) +#loc684 = loc(callsite(#loc542 at #loc510)) +#loc685 = loc(callsite(#loc543 at #loc510)) +#loc686 = loc(callsite(#loc544 at #loc510)) +#loc687 = loc(callsite(#loc545 at #loc510)) +#loc688 = loc(callsite(#loc546 at #loc510)) +#loc689 = loc(callsite(#loc547 at #loc510)) +#loc690 = loc(callsite(#loc548 at #loc510)) +#loc691 = loc(callsite(#loc549 at #loc510)) +#loc692 = loc(callsite(#loc550 at #loc510)) +#loc693 = loc(callsite(#loc551 at #loc510)) +#loc694 = loc(callsite(#loc552 at #loc510)) +#loc695 = loc(callsite(#loc553 at #loc510)) +#loc696 = loc(callsite(#loc554 at #loc510)) +#loc697 = loc(callsite(#loc555 at #loc510)) +#loc698 = loc(callsite(#loc556 at #loc510)) +#loc699 = loc(callsite(#loc557 at #loc510)) +#loc700 = loc(callsite(#loc558 at #loc510)) +#loc701 = loc(callsite(#loc559 at #loc510)) +#loc702 = loc(callsite(#loc560 at #loc510)) +#loc703 = loc(callsite(#loc561 at #loc510)) +#loc704 = loc(callsite(#loc562 at #loc510)) +#loc705 = loc(callsite(#loc563 at #loc510)) +#loc706 = loc(callsite(#loc564 at #loc510)) +#loc707 = loc(callsite(#loc565 at #loc510)) +#loc708 = loc(callsite(#loc566 at #loc196)) +#loc709 = loc(callsite(#loc567 at #loc196)) +#loc710 = loc(callsite(#loc568 at #loc196)) +#loc711 = loc(callsite(#loc569 at #loc196)) +#loc712 = loc(callsite(#loc570 at #loc196)) +#loc713 
= loc(callsite(#loc571 at #loc196)) +#loc714 = loc(callsite(#loc497 at #loc278)) +#loc715 = loc(callsite(#loc498 at #loc278)) +#loc716 = loc(callsite(#loc499 at #loc278)) +#loc717 = loc(callsite(#loc501 at #loc278)) +#loc718 = loc(callsite(#loc502 at #loc278)) +#loc719 = loc(callsite(#loc503 at #loc278)) +#loc720 = loc(callsite(#loc504 at #loc278)) +#loc721 = loc(callsite(#loc505 at #loc278)) +#loc722 = loc(callsite(#loc506 at #loc278)) +#loc723 = loc(callsite(#loc507 at #loc278)) +#loc724 = loc(callsite(#loc509 at #loc579)) +#loc725 = loc(callsite(#loc511 at #loc579)) +#loc726 = loc(callsite(#loc512 at #loc579)) +#loc727 = loc(callsite(#loc513 at #loc579)) +#loc728 = loc(callsite(#loc514 at #loc579)) +#loc729 = loc(callsite(#loc515 at #loc579)) +#loc730 = loc(callsite(#loc516 at #loc579)) +#loc731 = loc(callsite(#loc548 at #loc579)) +#loc732 = loc(callsite(#loc549 at #loc579)) +#loc733 = loc(callsite(#loc550 at #loc579)) +#loc734 = loc(callsite(#loc551 at #loc579)) +#loc735 = loc(callsite(#loc552 at #loc579)) +#loc736 = loc(callsite(#loc553 at #loc579)) +#loc737 = loc(callsite(#loc554 at #loc579)) +#loc738 = loc(callsite(#loc555 at #loc579)) +#loc739 = loc(callsite(#loc556 at #loc579)) +#loc740 = loc(callsite(#loc557 at #loc579)) +#loc741 = loc(callsite(#loc558 at #loc579)) +#loc742 = loc(callsite(#loc559 at #loc579)) +#loc743 = loc(callsite(#loc560 at #loc579)) +#loc744 = loc(callsite(#loc561 at #loc579)) +#loc745 = loc(callsite(#loc563 at #loc579)) +#loc746 = loc(callsite(#loc564 at #loc579)) +#loc747 = loc(callsite(#loc565 at #loc579)) +#loc748 = loc(callsite(#loc566 at #loc278)) +#loc749 = loc(callsite(#loc567 at #loc278)) +#loc750 = loc(callsite(#loc568 at #loc278)) +#loc751 = loc(callsite(#loc569 at #loc278)) +#loc752 = loc(callsite(#loc570 at #loc278)) +#loc753 = loc(callsite(#loc571 at #loc278)) +#loc754 = loc("kT_ptrs"(#loc607)) +#loc755 = loc(callsite(#loc384 at #loc608)) +#loc756 = loc(callsite(#loc386 at #loc608)) +#loc757 = loc(callsite(#loc387 at 
#loc608)) +#loc758 = loc(callsite(#loc388 at #loc608)) +#loc759 = loc(callsite(#loc389 at #loc608)) +#loc760 = loc(callsite(#loc390 at #loc608)) +#loc761 = loc(callsite(#loc391 at #loc608)) +#loc762 = loc(callsite(#loc392 at #loc608)) +#loc763 = loc(callsite(#loc393 at #loc608)) +#loc764 = loc(callsite(#loc394 at #loc608)) +#loc765 = loc(callsite(#loc395 at #loc608)) +#loc766 = loc(callsite(#loc396 at #loc608)) +#loc767 = loc(callsite(#loc397 at #loc608)) +#loc768 = loc(callsite(#loc398 at #loc608)) +#loc769 = loc(callsite(#loc399 at #loc608)) +#loc770 = loc(callsite(#loc400 at #loc608)) +#loc771 = loc(callsite(#loc401 at #loc608)) +#loc772 = loc(callsite(#loc402 at #loc608)) +#loc773 = loc(callsite(#loc403 at #loc608)) +#loc774 = loc(callsite(#loc404 at #loc608)) +#loc775 = loc(callsite(#loc405 at #loc608)) +#loc776 = loc(callsite(#loc406 at #loc608)) +#loc777 = loc(callsite(#loc407 at #loc608)) +#loc778 = loc(callsite(#loc408 at #loc608)) +#loc779 = loc(callsite(#loc409 at #loc608)) +#loc780 = loc(callsite(#loc410 at #loc608)) +#loc781 = loc(callsite(#loc411 at #loc608)) +#loc782 = loc(callsite(#loc412 at #loc608)) +#loc783 = loc(callsite(#loc413 at #loc608)) +#loc784 = loc(callsite(#loc414 at #loc608)) +#loc785 = loc(callsite(#loc415 at #loc608)) +#loc786 = loc(callsite(#loc416 at #loc608)) +#loc787 = loc(callsite(#loc417 at #loc608)) +#loc788 = loc(callsite(#loc418 at #loc608)) +#loc789 = loc(callsite(#loc419 at #loc608)) +#loc790 = loc(callsite(#loc420 at #loc608)) +#loc791 = loc(callsite(#loc421 at #loc608)) +#loc792 = loc(callsite(#loc422 at #loc608)) +#loc793 = loc(callsite(#loc423 at #loc608)) +#loc794 = loc(callsite(#loc424 at #loc608)) +#loc795 = loc(callsite(#loc425 at #loc608)) +#loc796 = loc(callsite(#loc426 at #loc608)) +#loc797 = loc(callsite(#loc427 at #loc608)) +#loc798 = loc(callsite(#loc428 at #loc608)) +#loc799 = loc(callsite(#loc429 at #loc608)) +#loc800 = loc(callsite(#loc430 at #loc608)) +#loc801 = loc(callsite(#loc431 at #loc609)) +#loc802 
= loc(callsite(#loc433 at #loc609)) +#loc803 = loc(callsite(#loc434 at #loc609)) +#loc804 = loc(callsite(#loc435 at #loc609)) +#loc805 = loc(callsite(#loc436 at #loc609)) +#loc806 = loc(callsite(#loc437 at #loc609)) +#loc807 = loc(callsite(#loc438 at #loc609)) +#loc808 = loc(callsite(#loc439 at #loc609)) +#loc809 = loc(callsite(#loc440 at #loc609)) +#loc810 = loc(callsite(#loc441 at #loc609)) +#loc811 = loc(callsite(#loc442 at #loc609)) +#loc812 = loc(callsite(#loc443 at #loc609)) +#loc813 = loc(callsite(#loc444 at #loc609)) +#loc814 = loc(callsite(#loc445 at #loc609)) +#loc815 = loc(callsite(#loc446 at #loc609)) +#loc816 = loc(callsite(#loc447 at #loc609)) +#loc817 = loc(callsite(#loc448 at #loc609)) +#loc818 = loc(callsite(#loc384 at #loc623)) +#loc819 = loc(callsite(#loc386 at #loc623)) +#loc820 = loc(callsite(#loc387 at #loc623)) +#loc821 = loc(callsite(#loc419 at #loc623)) +#loc822 = loc(callsite(#loc420 at #loc623)) +#loc823 = loc(callsite(#loc421 at #loc623)) +#loc824 = loc(callsite(#loc422 at #loc623)) +#loc825 = loc(callsite(#loc423 at #loc623)) +#loc826 = loc(callsite(#loc424 at #loc623)) +#loc827 = loc(callsite(#loc425 at #loc623)) +#loc828 = loc(callsite(#loc426 at #loc623)) +#loc829 = loc(callsite(#loc428 at #loc623)) +#loc830 = loc(callsite(#loc429 at #loc623)) +#loc831 = loc(callsite(#loc430 at #loc623)) +#loc832 = loc(callsite(#loc431 at #loc624)) +#loc833 = loc(callsite(#loc433 at #loc624)) +#loc834 = loc(callsite(#loc434 at #loc624)) +#loc835 = loc(callsite(#loc435 at #loc624)) +#loc836 = loc(callsite(#loc436 at #loc624)) +#loc837 = loc(callsite(#loc437 at #loc624)) +#loc838 = loc(callsite(#loc438 at #loc624)) +#loc839 = loc(callsite(#loc439 at #loc624)) +#loc840 = loc(callsite(#loc440 at #loc624)) +#loc841 = loc(callsite(#loc441 at #loc624)) +#loc842 = loc(callsite(#loc442 at #loc624)) +#loc843 = loc(callsite(#loc443 at #loc624)) +#loc844 = loc(callsite(#loc444 at #loc624)) +#loc845 = loc(callsite(#loc445 at #loc624)) +#loc846 = 
loc(callsite(#loc446 at #loc624)) +#loc847 = loc(callsite(#loc447 at #loc624)) +#loc848 = loc(callsite(#loc448 at #loc624)) +#loc849 = loc("offs_m1"(#loc651)) +#loc850 = loc(callsite(#loc53 at #loc652)) +#loc851 = loc(callsite(#loc53 at #loc694)) +#loc852 = loc(callsite(#loc431 at #loc708)) +#loc853 = loc(callsite(#loc433 at #loc708)) +#loc854 = loc(callsite(#loc434 at #loc708)) +#loc855 = loc(callsite(#loc435 at #loc708)) +#loc856 = loc(callsite(#loc436 at #loc708)) +#loc857 = loc(callsite(#loc437 at #loc708)) +#loc858 = loc(callsite(#loc438 at #loc708)) +#loc859 = loc(callsite(#loc439 at #loc708)) +#loc860 = loc(callsite(#loc440 at #loc708)) +#loc861 = loc(callsite(#loc441 at #loc708)) +#loc862 = loc(callsite(#loc442 at #loc708)) +#loc863 = loc(callsite(#loc443 at #loc708)) +#loc864 = loc(callsite(#loc444 at #loc708)) +#loc865 = loc(callsite(#loc445 at #loc708)) +#loc866 = loc(callsite(#loc446 at #loc708)) +#loc867 = loc(callsite(#loc447 at #loc708)) +#loc868 = loc(callsite(#loc448 at #loc708)) +#loc869 = loc(callsite(#loc53 at #loc724)) +#loc870 = loc(callsite(#loc53 at #loc735)) +#loc871 = loc(callsite(#loc431 at #loc748)) +#loc872 = loc(callsite(#loc433 at #loc748)) +#loc873 = loc(callsite(#loc434 at #loc748)) +#loc874 = loc(callsite(#loc435 at #loc748)) +#loc875 = loc(callsite(#loc436 at #loc748)) +#loc876 = loc(callsite(#loc437 at #loc748)) +#loc877 = loc(callsite(#loc438 at #loc748)) +#loc878 = loc(callsite(#loc439 at #loc748)) +#loc879 = loc(callsite(#loc440 at #loc748)) +#loc880 = loc(callsite(#loc441 at #loc748)) +#loc881 = loc(callsite(#loc442 at #loc748)) +#loc882 = loc(callsite(#loc443 at #loc748)) +#loc883 = loc(callsite(#loc444 at #loc748)) +#loc884 = loc(callsite(#loc445 at #loc748)) +#loc885 = loc(callsite(#loc446 at #loc748)) +#loc886 = loc(callsite(#loc447 at #loc748)) +#loc887 = loc(callsite(#loc448 at #loc748)) +#loc888 = loc("vT_ptrs"(#loc754)) +#loc889 = loc(callsite(#loc53 at #loc755)) +#loc890 = loc(callsite(#loc53 at #loc792)) +#loc891 = 
loc(callsite(#loc53 at #loc818)) +#loc892 = loc(callsite(#loc53 at #loc824)) +#loc893 = loc("qT_ptrs"(#loc849)) +#loc894 = loc(callsite(#loc888 at #loc374)) +#loc895 = loc(callsite(#loc888 at #loc459)) +#loc896 = loc("do_ptrs"(#loc893)) +#loc897 = loc(callsite(#loc896 at #loc196)) +#loc898 = loc(callsite(#loc896 at #loc278)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/__grp__triton_red_fused__to_copy_clone_slice_sum_transpose_5.json b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/__grp__triton_red_fused__to_copy_clone_slice_sum_transpose_5.json new file mode 100644 index 0000000000000000000000000000000000000000..7ed8f000a3ef35c5e4d64d848df01c6bb83a0394 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/__grp__triton_red_fused__to_copy_clone_slice_sum_transpose_5.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused__to_copy_clone_slice_sum_transpose_5.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.source", "triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttir", "triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttgir", "triton_red_fused__to_copy_clone_slice_sum_transpose_5.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.llir", 
"triton_red_fused__to_copy_clone_slice_sum_transpose_5.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ptx", "triton_red_fused__to_copy_clone_slice_sum_transpose_5.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.cubin", "triton_red_fused__to_copy_clone_slice_sum_transpose_5.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.cubin new file mode 100644 index 0000000000000000000000000000000000000000..b649e96e56e2df419f373e3d26e0b0c78ae4143e Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.json b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.json new file mode 100644 index 0000000000000000000000000000000000000000..1015e7a98d0783fffaced7eb1f8bbfeb2d07ee63 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.json @@ -0,0 +1 @@ +{"hash": "4baac217b40eeef03b4662a4341ac284b2fd62d1dd17cf76e5585c3c1c4b0d3a", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 2, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 128, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused__to_copy_clone_slice_sum_transpose_5"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.llir new file mode 100644 index 0000000000000000000000000000000000000000..372c166306fe8ee5017c73ee30615348807e98e0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.llir @@ -0,0 +1,206 @@ +; ModuleID = 
'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused__to_copy_clone_slice_sum_transpose_5(ptr addrspace(1) %0, ptr addrspace(1) %1, i64 %2, i64 %3, i32 %4, i32 %5, ptr addrspace(1) readnone captures(none) %6, ptr addrspace(1) readnone captures(none) %7) local_unnamed_addr #0 !dbg !4 { + %9 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %10 = shl i32 %9, 3, !dbg !8 + %11 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %12 = and i32 %11, 7, !dbg !9 + %13 = or disjoint i32 %10, %12, !dbg !10 + %14 = icmp slt i32 %13, %4, !dbg !11 + %.fr = freeze i1 %14 + %15 = lshr i32 %11, 3, !dbg !12 + %16 = and i32 %15, 7, !dbg !12 + %17 = or disjoint i32 %16, 8, !dbg !12 + %18 = sext i32 %13 to i64, !dbg !13 + %.frozen = freeze i64 %2, !dbg !14 + %19 = sdiv i64 %18, %.frozen, !dbg !14 + %20 = mul i64 %19, %.frozen, !dbg !13 + %.decomposed = sub i64 %18, %20, !dbg !13 + %21 = icmp sgt i32 %5, 0, !dbg !15 + br i1 %21, label %.lr.ph, label %._crit_edge, !dbg !15 + +.lr.ph: ; preds = %8 + %22 = mul i64 %3, %2, !dbg !16 + %23 = mul i64 %22, %19, !dbg !17 + %24 = getelementptr i32, ptr addrspace(1) %0, i64 %.decomposed + %invariant.gep = getelementptr i32, ptr addrspace(1) %24, i64 %23, !dbg !15 + br i1 %.fr, label %.lr.ph.split, label %.lr.ph.split.us + +.lr.ph.split.us: ; preds = %.lr.ph, %.lr.ph.split.us + %25 = phi i32 [ %36, %.lr.ph.split.us ], [ 0, %.lr.ph ] + %26 = or disjoint i32 %25, %16, !dbg !18 + %27 = or disjoint i32 %17, %25, !dbg !18 + %28 = sext i32 %26 to i64, !dbg !19 + %29 = sext i32 %27 to i64, !dbg !19 + %30 = mul i64 %2, %28, !dbg !19 + %31 = mul i64 %2, %29, !dbg !19 + %gep.us = getelementptr i32, ptr addrspace(1) %invariant.gep, i64 %30, !dbg !20 + %gep5.us = getelementptr i32, ptr 
addrspace(1) %invariant.gep, i64 %31, !dbg !20 + %32 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #4, !dbg !21 + %33 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %gep.us, i64 %32, i1 false) #4, !dbg !21 + %34 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #4, !dbg !21 + %35 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %gep5.us, i64 %34, i1 false) #4, !dbg !21 + %36 = add i32 %25, 16, !dbg !15 + %37 = icmp slt i32 %36, %5, !dbg !15 + br i1 %37, label %.lr.ph.split.us, label %._crit_edge, !dbg !15 + +.lr.ph.split: ; preds = %.lr.ph, %.lr.ph.split + %38 = phi i64 [ %53, %.lr.ph.split ], [ 0, %.lr.ph ] + %39 = phi i64 [ %55, %.lr.ph.split ], [ 0, %.lr.ph ] + %40 = phi i32 [ %56, %.lr.ph.split ], [ 0, %.lr.ph ] + %41 = or disjoint i32 %40, %16, !dbg !18 + %42 = or disjoint i32 %17, %40, !dbg !18 + %43 = icmp slt i32 %41, %5, !dbg !22 + %44 = icmp slt i32 %42, %5, !dbg !22 + %45 = sext i32 %41 to i64, !dbg !19 + %46 = sext i32 %42 to i64, !dbg !19 + %47 = mul i64 %2, %45, !dbg !19 + %48 = mul i64 %2, %46, !dbg !19 + %gep = getelementptr i32, ptr addrspace(1) %invariant.gep, i64 %47, !dbg !20 + %gep5 = getelementptr i32, ptr addrspace(1) %invariant.gep, i64 %48, !dbg !20 + %49 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #4, !dbg !21 + %50 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %gep, i64 %49, i1 %43) #4, !dbg !21 + %51 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #4, 
!dbg !21 + %52 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %gep5, i64 %51, i1 %44) #4, !dbg !21 + %narrow = select i1 %43, i32 %50, i32 0, !dbg !23 + %spec.select = sext i32 %narrow to i64, !dbg !23 + %53 = add i64 %38, %spec.select, !dbg !23 + %narrow6 = select i1 %44, i32 %52, i32 0, !dbg !23 + %54 = sext i32 %narrow6 to i64, !dbg !23 + %55 = add i64 %39, %54, !dbg !23 + %56 = add i32 %40, 16, !dbg !15 + %57 = icmp slt i32 %56, %5, !dbg !15 + br i1 %57, label %.lr.ph.split, label %._crit_edge, !dbg !15 + +._crit_edge: ; preds = %.lr.ph.split.us, %.lr.ph.split, %8 + %58 = phi i64 [ 0, %8 ], [ %53, %.lr.ph.split ], [ 0, %.lr.ph.split.us ], !dbg !24 + %59 = phi i64 [ 0, %8 ], [ %55, %.lr.ph.split ], [ 0, %.lr.ph.split.us ], !dbg !24 + %60 = and i32 %11, 24, !dbg !9 + %61 = lshr i32 %11, 5, !dbg !9 + %62 = add i64 %58, %59, !dbg !25 + %extelt.offset = lshr i64 %62, 32, !dbg !29 + %63 = trunc nuw i64 %extelt.offset to i32, !dbg !29 + %64 = trunc i64 %62 to i32, !dbg !29 + %65 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %64, i32 16, i32 31), !dbg !29 + %66 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %63, i32 16, i32 31), !dbg !29 + %67 = insertelement <2 x i32> poison, i32 %65, i64 0, !dbg !29 + %68 = insertelement <2 x i32> %67, i32 %66, i64 1, !dbg !29 + %69 = bitcast <2 x i32> %68 to i64, !dbg !29 + %70 = add i64 %62, %69, !dbg !25 + %extelt.offset2 = lshr i64 %70, 32, !dbg !29 + %71 = trunc nuw i64 %extelt.offset2 to i32, !dbg !29 + %72 = trunc i64 %70 to i32, !dbg !29 + %73 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %72, i32 8, i32 31), !dbg !29 + %74 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %71, i32 8, i32 31), !dbg !29 + %75 = insertelement <2 x i32> poison, i32 %73, i64 0, !dbg !29 + %76 = insertelement <2 x i32> %75, i32 %74, i64 1, !dbg !29 + %77 = bitcast <2 x i32> %76 to i64, !dbg !29 + 
%78 = add i64 %70, %77, !dbg !25 + %79 = and i32 %61, 1, !dbg !29 + %80 = icmp eq i32 %60, 0, !dbg !29 + %.idx = shl nuw nsw i32 %12, 4, !dbg !29 + %81 = getelementptr i8, ptr addrspace(3) @global_smem, i32 %.idx, !dbg !29 + %82 = getelementptr i64, ptr addrspace(3) %81, i32 %79, !dbg !29 + %83 = insertelement <1 x i64> poison, i64 %78, i64 0, !dbg !29 + tail call void asm sideeffect "@$2 st.shared.b64 [ $0 + 0 ], $1;", "r,l,b"(ptr addrspace(3) %82, <1 x i64> %83, i1 %80) #4, !dbg !29 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !29 + %84 = icmp samesign ult i32 %11, 16, !dbg !29 + %85 = getelementptr i64, ptr addrspace(3) @global_smem, i32 %11, !dbg !29 + %86 = tail call i64 asm sideeffect "@$2 ld.shared.b64 $0, [ $1 + 0 ];", "=l,r,b"(ptr addrspace(3) %85, i1 %84) #4, !dbg !29 + %extelt.offset3 = lshr i64 %86, 32, !dbg !29 + %87 = trunc nuw i64 %extelt.offset3 to i32, !dbg !29 + %88 = trunc i64 %86 to i32, !dbg !29 + %89 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %88, i32 1, i32 31), !dbg !29 + %90 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %87, i32 1, i32 31), !dbg !29 + %91 = insertelement <2 x i32> poison, i32 %89, i64 0, !dbg !29 + %92 = insertelement <2 x i32> %91, i32 %90, i64 1, !dbg !29 + %93 = bitcast <2 x i32> %92 to i64, !dbg !29 + %94 = add i64 %86, %93, !dbg !25 + %95 = and i32 %11, 1009, !dbg !29 + %96 = icmp eq i32 %95, 0, !dbg !29 + %97 = insertelement <1 x i64> poison, i64 %94, i64 0, !dbg !29 + tail call void asm sideeffect "@$2 st.shared.b64 [ $0 + 0 ], $1;", "r,l,b"(ptr addrspace(3) %85, <1 x i64> %97, i1 %96) #4, !dbg !29 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !29 + %98 = load i64, ptr addrspace(3) %81, align 16, !dbg !29 + %99 = trunc i64 %98 to i32, !dbg !30 + %100 = icmp slt i64 %2, 2, !dbg !31 + %101 = icmp sgt i64 %2, 1, !dbg !32 + %102 = select i1 %101, i64 %2, i64 0, !dbg !33 + %103 = zext i1 %100 to i64, !dbg !34 + %104 = add i64 %102, %103, !dbg !35 
+ %105 = mul i64 %19, %104, !dbg !36 + %106 = getelementptr i32, ptr addrspace(1) %1, i64 %.decomposed, !dbg !37 + %107 = getelementptr i32, ptr addrspace(1) %106, i64 %105, !dbg !37 + %108 = and i32 %11, 56, !dbg !38 + %109 = icmp eq i32 %108, 0, !dbg !38 + %110 = and i1 %109, %.fr, !dbg !38 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %99, ptr addrspace(1) %107, i1 %110) #4, !dbg !38 + ret void, !dbg !39 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="64" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused__to_copy_clone_slice_sum_transpose_5", linkageName: 
"triton_red_fused__to_copy_clone_slice_sum_transpose_5", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 21, column: 28, scope: !4) +!8 = !DILocation(line: 21, column: 33, scope: !4) +!9 = !DILocation(line: 22, column: 44, scope: !4) +!10 = !DILocation(line: 22, column: 23, scope: !4) +!11 = !DILocation(line: 23, column: 21, scope: !4) +!12 = !DILocation(line: 24, column: 37, scope: !4) +!13 = !DILocation(line: 26, column: 19, scope: !4) +!14 = !DILocation(line: 27, column: 19, scope: !4) +!15 = !DILocation(line: 30, column: 40, scope: !4) +!16 = !DILocation(line: 36, column: 54, scope: !4) +!17 = !DILocation(line: 36, column: 58, scope: !4) +!18 = !DILocation(line: 31, column: 31, scope: !4) +!19 = !DILocation(line: 36, column: 43, scope: !4) +!20 = !DILocation(line: 36, column: 34, scope: !4) +!21 = !DILocation(line: 36, column: 63, scope: !4) +!22 = !DILocation(line: 32, column: 29, scope: !4) +!23 = !DILocation(line: 40, column: 48, scope: !4) +!24 = !DILocation(line: 28, column: 43, scope: !4) +!25 = !DILocation(line: 261, column: 15, scope: !26, inlinedAt: !28) +!26 = distinct !DILexicalBlockFile(scope: !4, file: !27, discriminator: 0) +!27 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!28 = !DILocation(line: 41, column: 25, scope: !4) +!29 = !DILocation(line: 291, column: 36, scope: !26, inlinedAt: !28) +!30 = !DILocation(line: 42, column: 19, scope: !4) +!31 = !DILocation(line: 43, column: 49, scope: !4) +!32 = !DILocation(line: 43, column: 75, scope: !4) +!33 = !DILocation(line: 43, column: 66, scope: !4) +!34 = !DILocation(line: 43, scope: !4) +!35 = !DILocation(line: 43, column: 57, scope: !4) +!36 = !DILocation(line: 43, column: 34, scope: !4) +!37 = !DILocation(line: 43, column: 25, scope: !4) +!38 = !DILocation(line: 43, column: 
88, scope: !4) +!39 = !DILocation(line: 43, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ptx new file mode 100644 index 0000000000000000000000000000000000000000..d430ed0ea7f0fe9af61d162a85e9e1f63c0d902b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ptx @@ -0,0 +1,536 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused__to_copy_clone_slice_sum_transpose_5 // -- Begin function triton_red_fused__to_copy_clone_slice_sum_transpose_5 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused__to_copy_clone_slice_sum_transpose_5 +.visible .entry triton_red_fused__to_copy_clone_slice_sum_transpose_5( + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_1, + .param .u64 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_2, + .param .u64 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_3, + .param .u32 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_4, + .param .u32 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_7 +) +.reqntid 64 +{ + .reg .pred %p<17>; + .reg .b32 %r<57>; + .reg .b64 %rd<87>; + .loc 1 18 0 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:18:0 +$L__func_begin0: + 
.loc 1 18 0 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:18:0 + +// %bb.0: + ld.param.b32 %r8, [triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_5]; + ld.param.b64 %rd16, [triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_2]; +$L__tmp0: + .loc 1 21 28 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:21:28 + mov.u32 %r9, %ctaid.x; + .loc 1 21 33 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:21:33 + shl.b32 %r10, %r9, 3; + .loc 1 22 44 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:22:44 + mov.u32 %r1, %tid.x; + ld.param.b32 %r11, [triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_4]; + and.b32 %r2, %r1, 7; + .loc 1 22 23 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:22:23 + or.b32 %r13, %r10, %r2; + .loc 1 26 19 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:26:19 + cvt.s64.s32 %rd1, %r13; + .loc 1 27 19 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:27:19 + or.b64 %rd19, %rd1, %rd16; + and.b64 %rd20, %rd19, -4294967296; + setp.ne.b64 %p2, %rd20, 0; + @%p2 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + div.s64 %rd82, %rd1, %rd16; + bra.uni $L__BB0_3; +$L__BB0_1: + cvt.u32.u64 %r15, %rd16; + cvt.u32.u64 %r16, %rd1; + div.u32 %r17, %r16, %r15; + cvt.u64.u32 %rd82, %r17; +$L__BB0_3: + .loc 1 0 19 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:0:19 + ld.param.b64 %rd15, [triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_1]; + setp.lt.s32 %p1, %r13, %r11; + .loc 1 26 19 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:26:19 + mul.lo.s64 %rd22, %rd82, %rd16; + sub.s64 %rd6, %rd1, %rd22; + .loc 1 30 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:30:40 + setp.lt.s32 %p3, %r8, 1; + mov.b64 %rd85, 0; + shl.b64 %rd81, %rd6, 2; + mov.b64 %rd86, %rd85; + @%p3 bra $L__BB0_9; +// %bb.4: // %.lr.ph + .loc 1 0 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:0:40 + ld.param.b64 %rd17, 
[triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_3]; + ld.param.b64 %rd14, [triton_red_fused__to_copy_clone_slice_sum_transpose_5_param_0]; + bfe.u32 %r3, %r1, 3, 3; + .loc 1 36 54 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:54 + mul.lo.s64 %rd23, %rd17, %rd16; + .loc 1 36 58 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:58 + mul.lo.s64 %rd24, %rd23, %rd82; + add.s64 %rd26, %rd14, %rd81; + .loc 1 30 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:30:40 + shl.b64 %rd27, %rd24, 2; + add.s64 %rd7, %rd26, %rd27; + @%p1 bra $L__BB0_7; + bra.uni $L__BB0_5; +$L__BB0_7: // %.lr.ph.split.preheader + .loc 1 0 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:0:40 + mov.b32 %r56, 0; + mov.b64 %rd85, 0; + mov.b64 %rd86, %rd85; +$L__BB0_8: // %.lr.ph.split + // =>This Inner Loop Header: Depth=1 + .loc 1 31 31 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:31:31 + add.s32 %r26, %r3, %r56; + .loc 1 32 29 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:32:29 + add.s32 %r27, %r26, 8; + setp.lt.s32 %p7, %r26, %r8; + setp.lt.s32 %p8, %r27, %r8; + .loc 1 36 43 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:43 + cvt.s64.s32 %rd48, %r26; + cvt.s64.s32 %rd49, %r27; + mul.lo.s64 %rd50, %rd16, %rd48; + mul.lo.s64 %rd51, %rd16, %rd49; + .loc 1 36 34 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:34 + shl.b64 %rd52, %rd50, 2; + add.s64 %rd43, %rd7, %rd52; + shl.b64 %rd53, %rd51, 2; + add.s64 %rd46, %rd7, %rd53; + .loc 1 36 63 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:63 + // begin inline asm + mov.u64 %rd42, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd42, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r24, 0x0; + @%p7 ld.global.L1::evict_last.L2::cache_hint.b32 { %r24 }, [ %rd43 + 0 ], %rd42; + // end inline asm + // begin inline asm + mov.u64 %rd45, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd45, 1.0; + // end inline 
asm + // begin inline asm + mov.u32 %r25, 0x0; + @%p8 ld.global.L1::evict_last.L2::cache_hint.b32 { %r25 }, [ %rd46 + 0 ], %rd45; + // end inline asm + .loc 1 40 48 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:40:48 + selp.b32 %r28, %r24, 0, %p7; + cvt.s64.s32 %rd54, %r28; + add.s64 %rd85, %rd85, %rd54; + selp.b32 %r29, %r25, 0, %p8; + cvt.s64.s32 %rd55, %r29; + add.s64 %rd86, %rd86, %rd55; + .loc 1 30 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:30:40 + add.s32 %r56, %r56, 16; + setp.lt.s32 %p9, %r56, %r8; + @%p9 bra $L__BB0_8; + bra.uni $L__BB0_9; +$L__BB0_5: // %.lr.ph.split.us.preheader + .loc 1 0 40 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:0:40 + mov.b32 %r55, 0; +$L__BB0_6: // %.lr.ph.split.us + // =>This Inner Loop Header: Depth=1 + .loc 1 31 31 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:31:31 + add.s32 %r21, %r3, %r55; + .loc 1 36 43 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:43 + add.s32 %r22, %r21, 8; + cvt.s64.s32 %rd35, %r21; + cvt.s64.s32 %rd36, %r22; + mul.lo.s64 %rd37, %rd16, %rd35; + mul.lo.s64 %rd38, %rd16, %rd36; + .loc 1 36 34 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:34 + shl.b64 %rd39, %rd37, 2; + add.s64 %rd29, %rd7, %rd39; + shl.b64 %rd40, %rd38, 2; + add.s64 %rd32, %rd7, %rd40; + .loc 1 36 63 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:36:63 + // begin inline asm + mov.u64 %rd28, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd28, 1.0; + // end inline asm + mov.pred %p4, 0; + // begin inline asm + mov.u32 %r19, 0x0; + @%p4 ld.global.L1::evict_last.L2::cache_hint.b32 { %r19 }, [ %rd29 + 0 ], %rd28; + // end inline asm + // begin inline asm + mov.u64 %rd31, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd31, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r20, 0x0; + @%p4 ld.global.L1::evict_last.L2::cache_hint.b32 { %r20 }, [ %rd32 + 0 ], %rd31; + // end inline asm + .loc 1 30 40 // 
cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:30:40 + add.s32 %r55, %r55, 16; + setp.lt.s32 %p6, %r55, %r8; + mov.b64 %rd86, %rd85; + @%p6 bra $L__BB0_6; +$L__BB0_9: // %._crit_edge + .loc 1 22 44 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:22:44 + and.b32 %r34, %r1, 24; +$L__tmp1: + .loc 2 261 15 // standard.py:261:15 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + add.s64 %rd60, %rd85, %rd86; + .loc 2 291 36 // standard.py:291:36 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + mov.b64 {_, %r35}, %rd60; + cvt.u32.u64 %r36, %rd60; + shfl.sync.bfly.b32 %r37, %r36, 16, 31, -1; + shfl.sync.bfly.b32 %r38, %r35, 16, 31, -1; + cvt.u64.u32 %rd61, %r37; + cvt.u64.u32 %rd62, %r38; + shl.b64 %rd63, %rd62, 32; + or.b64 %rd64, %rd61, %rd63; + .loc 2 261 15 // standard.py:261:15 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + add.s64 %rd65, %rd60, %rd64; + .loc 2 291 36 // standard.py:291:36 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + mov.b64 {_, %r39}, %rd65; + cvt.u32.u64 %r40, %rd65; + shfl.sync.bfly.b32 %r41, %r40, 8, 31, -1; + shfl.sync.bfly.b32 %r42, %r39, 8, 31, -1; + cvt.u64.u32 %rd66, %r41; + cvt.u64.u32 %rd67, %r42; + shl.b64 %rd68, %rd67, 32; + or.b64 %rd69, %rd66, %rd68; + .loc 2 261 15 // standard.py:261:15 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + add.s64 %rd56, %rd65, %rd69; + .loc 2 291 36 // standard.py:291:36 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + setp.eq.b32 %p10, %r34, 0; + shl.b32 %r43, %r2, 4; + mov.b32 %r44, global_smem; + add.s32 %r45, %r44, %r43; + shr.u32 %r46, %r1, 2; + and.b32 %r47, %r46, 8; + add.s32 %r30, %r45, %r47; + // begin inline asm + @%p10 st.shared.b64 [ %r30 + 0 ], %rd56; + // end inline asm + bar.sync 0; + setp.lt.u32 %p11, %r1, 16; + shl.b32 %r48, %r1, 3; + add.s32 %r31, %r44, %r48; + // begin inline asm + @%p11 ld.shared.b64 %rd57, [ %r31 + 0 ]; + // end inline asm + 
mov.b64 {_, %r49}, %rd57; + cvt.u32.u64 %r50, %rd57; + shfl.sync.bfly.b32 %r51, %r50, 1, 31, -1; + shfl.sync.bfly.b32 %r52, %r49, 1, 31, -1; + cvt.u64.u32 %rd70, %r51; + cvt.u64.u32 %rd71, %r52; + shl.b64 %rd72, %rd71, 32; + or.b64 %rd73, %rd70, %rd72; + .loc 2 261 15 // standard.py:261:15 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + add.s64 %rd58, %rd57, %rd73; + .loc 2 291 36 // standard.py:291:36 @[ cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:41:25 ] + and.b32 %r53, %r1, 1009; + setp.eq.b32 %p12, %r53, 0; + // begin inline asm + @%p12 st.shared.b64 [ %r31 + 0 ], %rd58; + // end inline asm + bar.sync 0; + ld.shared.b32 %r33, [%r45]; +$L__tmp2: + .loc 1 43 49 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:49 + setp.lt.s64 %p14, %rd16, 2; + .loc 1 43 75 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:75 + setp.gt.s64 %p15, %rd16, 1; + .loc 1 43 66 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:66 + selp.b64 %rd74, %rd16, 0, %p15; + .loc 1 43 0 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43 + selp.b64 %rd75, 1, 0, %p14; + .loc 1 43 57 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:57 + add.s64 %rd76, %rd74, %rd75; + .loc 1 43 34 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:34 + mul.lo.s64 %rd77, %rd82, %rd76; + .loc 1 43 25 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:25 + add.s64 %rd79, %rd15, %rd81; + shl.b64 %rd80, %rd77, 2; + add.s64 %rd59, %rd79, %rd80; + .loc 1 43 88 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:88 + and.b32 %r54, %r1, 56; + setp.eq.b32 %p16, %r54, 0; + and.pred %p13, %p16, %p1; + // begin inline asm + @%p13 st.global.b32 [ %rd59 + 0 ], { %r33 }; + // end inline asm + .loc 1 43 4 // cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py:43:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 238 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. 
Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xe7 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 119 +.b8 99 +.b8 50 +.b8 106 +.b8 99 +.b8 116 +.b8 122 +.b8 54 +.b8 55 +.b8 115 +.b8 55 +.b8 110 +.b8 54 +.b8 102 +.b8 111 +.b8 116 +.b8 108 +.b8 50 +.b8 50 +.b8 103 +.b8 114 +.b8 52 +.b8 52 +.b8 103 +.b8 108 +.b8 105 +.b8 106 +.b8 114 +.b8 99 +.b8 102 +.b8 120 +.b8 112 +.b8 101 +.b8 98 +.b8 108 +.b8 118 +.b8 107 +.b8 51 +.b8 50 +.b8 104 +.b8 111 +.b8 114 +.b8 110 +.b8 108 +.b8 108 +.b8 108 +.b8 122 +.b8 110 +.b8 120 +.b8 50 +.b8 109 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 119 +.b8 99 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x38 DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 95 +.b8 116 +.b8 111 +.b8 95 +.b8 99 +.b8 111 +.b8 112 +.b8 121 +.b8 95 +.b8 99 +.b8 108 +.b8 111 +.b8 110 +.b8 101 +.b8 95 +.b8 115 +.b8 108 +.b8 105 +.b8 99 +.b8 101 +.b8 95 +.b8 115 +.b8 117 +.b8 109 +.b8 95 +.b8 116 +.b8 114 +.b8 97 +.b8 110 +.b8 115 +.b8 112 +.b8 111 +.b8 115 +.b8 101 +.b8 95 +.b8 53 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xc3:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 
0xd8:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 41 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.source b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.source new file mode 100644 index 0000000000000000000000000000000000000000..b38c8d9778cccf9217fa4f1aa83935b9055a4dfe --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.source @@ -0,0 +1,193 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":18:0) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc43 = loc(unknown) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc50 = loc("in_ptr0"(#loc)) +#loc51 = loc("out_ptr1"(#loc)) +#loc52 = loc("ks0"(#loc)) +#loc53 = loc("ks1"(#loc)) +#loc54 = loc("xnumel"(#loc)) +#loc55 = loc("r0_numel"(#loc)) +#loc85 = loc("input"(#loc41)) +#loc86 = loc("a"(#loc46)) +#loc87 = loc("b"(#loc46)) +module { + tt.func public @triton_red_fused__to_copy_clone_slice_sum_transpose_5(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 loc("ks1"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %xoffset = tt.get_program_id x : i32 
loc(#loc56) + %xoffset_0 = arith.constant 8 : i32 loc(#loc57) + %xoffset_1 = arith.constant 8 : i32 loc(#loc57) + %xoffset_2 = arith.muli %xoffset, %xoffset_1 : i32 loc(#loc57) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc58) + %xindex_3 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc59) + %xindex_4 = tt.splat %xoffset_2 : i32 -> tensor<8x1xi32> loc(#loc60) + %xindex_5 = arith.addi %xindex_4, %xindex_3 : tensor<8x1xi32> loc(#loc60) + %xmask = tt.splat %xnumel : i32 -> tensor<8x1xi32> loc(#loc61) + %xmask_6 = arith.cmpi slt, %xindex_5, %xmask : tensor<8x1xi32> loc(#loc61) + %r0_base = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc62) + %r0_base_7 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc63) + %x0 = arith.extsi %xindex_5 : tensor<8x1xi32> to tensor<8x1xi64> loc(#loc64) + %x0_8 = tt.splat %ks0 : i64 -> tensor<8x1xi64> loc(#loc64) + %x0_9 = arith.remsi %x0, %x0_8 : tensor<8x1xi64> loc(#loc64) + %x1 = arith.extsi %xindex_5 : tensor<8x1xi32> to tensor<8x1xi64> loc(#loc65) + %x1_10 = tt.splat %ks0 : i64 -> tensor<8x1xi64> loc(#loc65) + %x1_11 = arith.divsi %x1, %x1_10 : tensor<8x1xi64> loc(#loc65) + %_tmp3 = arith.constant 0 : i64 loc(#loc66) + %_tmp3_12 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc66) + %c0_i32 = arith.constant 0 : i32 loc(#loc12) + %c16_i32 = arith.constant 16 : i32 loc(#loc12) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc12) + %1 = arith.bitcast %r0_numel : i32 to i32 loc(#loc12) + %2 = arith.bitcast %c16_i32 : i32 to i32 loc(#loc12) + %3 = ub.poison : i32 loc(#loc12) + %_tmp3_13 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp3_18 = %_tmp3_12) -> (tensor<8x16xi64>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x16xi32> loc(#loc68) + %r0_index_19 = arith.addi %r0_index, %r0_base_7 : tensor<1x16xi32> loc(#loc68) + %r0_mask = tt.splat %r0_numel : i32 -> tensor<1x16xi32> 
loc(#loc69) + %r0_mask_20 = arith.cmpi slt, %r0_index_19, %r0_mask : tensor<1x16xi32> loc(#loc69) + %tmp0 = arith.extsi %r0_index_19 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc70) + %tmp0_21 = tt.splat %ks0 : i64 -> tensor<1x16xi64> loc(#loc70) + %tmp0_22 = arith.muli %tmp0_21, %tmp0 : tensor<1x16xi64> loc(#loc70) + %tmp0_23 = tt.broadcast %x0_9 : tensor<8x1xi64> -> tensor<8x16xi64> loc(#loc71) + %tmp0_24 = tt.broadcast %tmp0_22 : tensor<1x16xi64> -> tensor<8x16xi64> loc(#loc71) + %tmp0_25 = arith.addi %tmp0_23, %tmp0_24 : tensor<8x16xi64> loc(#loc71) + %tmp0_26 = arith.muli %ks0, %ks1 : i64 loc(#loc72) + %tmp0_27 = tt.splat %tmp0_26 : i64 -> tensor<8x1xi64> loc(#loc73) + %tmp0_28 = arith.muli %tmp0_27, %x1_11 : tensor<8x1xi64> loc(#loc73) + %tmp0_29 = tt.broadcast %tmp0_28 : tensor<8x1xi64> -> tensor<8x16xi64> loc(#loc74) + %tmp0_30 = arith.addi %tmp0_25, %tmp0_29 : tensor<8x16xi64> loc(#loc74) + %tmp0_31 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc75) + %tmp0_32 = tt.addptr %tmp0_31, %tmp0_30 : tensor<8x16x!tt.ptr>, tensor<8x16xi64> loc(#loc75) + %tmp0_33 = tt.broadcast %r0_mask_20 : tensor<1x16xi1> -> tensor<8x16xi1> loc(#loc76) + %tmp0_34 = tt.broadcast %xmask_6 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc76) + %tmp0_35 = arith.andi %tmp0_33, %tmp0_34 : tensor<8x16xi1> loc(#loc76) + %tmp0_36 = arith.constant 0.000000e+00 : f32 loc(#loc77) + %tmp0_37 = arith.constant dense<0.000000e+00> : tensor<8x16xf32> loc(#loc77) + %tmp0_38 = arith.fptosi %tmp0_37 : tensor<8x16xf32> to tensor<8x16xi32> loc(#loc77) + %tmp0_39 = tt.load %tmp0_32, %tmp0_35, %tmp0_38 evictionPolicy = evict_last : tensor<8x16x!tt.ptr> loc(#loc77) + %tmp1 = arith.extsi %tmp0_39 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc78) + %tmp4 = arith.addi %_tmp3_18, %tmp1 : tensor<8x16xi64> loc(#loc79) + %_tmp3_40 = tt.broadcast %r0_mask_20 : tensor<1x16xi1> -> tensor<8x16xi1> loc(#loc80) + %_tmp3_41 = tt.broadcast %xmask_6 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc80) + %_tmp3_42 
= arith.andi %_tmp3_40, %_tmp3_41 : tensor<8x16xi1> loc(#loc80) + %_tmp3_43 = arith.select %_tmp3_42, %tmp4, %_tmp3_18 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc81) + scf.yield %_tmp3_43 : tensor<8x16xi64> loc(#loc27) + } loc(#loc67) + %tmp3 = tt.call @"triton.language.standard.sum__i64S8_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp3_13) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc82) + %tmp3_14 = tt.expand_dims %tmp3 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc83) + %tmp5 = arith.trunci %tmp3_14 : tensor<8x1xi64> to tensor<8x1xi32> loc(#loc84) + %c1_i32 = arith.constant 1 : i32 loc(#loc31) + %4 = arith.extsi %c1_i32 : i32 to i64 loc(#loc31) + %5 = arith.cmpi sge, %4, %ks0 : i64 loc(#loc31) + %c1_i32_15 = arith.constant 1 : i32 loc(#loc32) + %c1_i32_16 = arith.constant 1 : i32 loc(#loc32) + %6 = arith.extui %5 : i1 to i32 loc(#loc32) + %7 = arith.muli %c1_i32_16, %6 : i32 loc(#loc32) + %c1_i32_17 = arith.constant 1 : i32 loc(#loc33) + %8 = arith.extsi %c1_i32_17 : i32 to i64 loc(#loc33) + %9 = arith.cmpi sgt, %ks0, %8 : i64 loc(#loc33) + %10 = arith.extui %9 : i1 to i64 loc(#loc34) + %11 = arith.muli %ks0, %10 : i64 loc(#loc34) + %12 = arith.extsi %7 : i32 to i64 loc(#loc35) + %13 = arith.addi %12, %11 : i64 loc(#loc35) + %14 = tt.splat %13 : i64 -> tensor<8x1xi64> loc(#loc36) + %15 = arith.muli %x1_11, %14 : tensor<8x1xi64> loc(#loc36) + %16 = arith.addi %x0_9, %15 : tensor<8x1xi64> loc(#loc37) + %17 = tt.splat %out_ptr1 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc38) + %18 = tt.addptr %17, %16 : tensor<8x1x!tt.ptr>, tensor<8x1xi64> loc(#loc38) + tt.store %18, %tmp5, %xmask_6 : tensor<8x1x!tt.ptr> loc(#loc39) + tt.return loc(#loc40) + } loc(#loc) + tt.func private @"triton.language.standard.sum__i64S8_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<8x16xi64> loc("input"(#loc41))) -> tensor<8xi64> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i64 loc(unknown), 
%arg2: i64 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i64_i64__(%arg1, %arg2) : (i64, i64) -> i64 loc(#loc42) + tt.reduce.return %2 : i64 loc(#loc42) + }) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc42) + tt.return %0 : tensor<8xi64> loc(#loc44) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8xi64> loc(#loc45) + tt.return %1 : tensor<8xi64> loc(#loc45) + } loc(#loc41) + tt.func private @triton.language.standard._sum_combine__i64_i64__(%a: i64 loc("a"(#loc46)), %b: i64 loc("b"(#loc46))) -> i64 attributes {noinline = false} { + %0 = arith.addi %a, %b : i64 loc(#loc47) + tt.return %0 : i64 loc(#loc48) + ^bb1: // no predecessors + %1 = ub.poison : i64 loc(#loc49) + tt.return %1 : i64 loc(#loc49) + } loc(#loc46) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:28) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:33) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:36) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":23:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":24:27) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":24:37) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":26:19) +#loc10 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":27:19) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":28:43) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":30:40) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":31:31) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":32:29) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:43) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:39) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:54) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:58) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:50) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:34) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:73) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:63) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":37:23) +#loc24 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":39:23) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:35) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:48) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:8) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:25) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:28) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":42:19) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:49) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:41) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:75) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:66) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:57) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:34) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:30) +#loc38 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:25) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:88) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:4) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc56 = loc("xoffset"(#loc1)) +#loc57 = loc("xoffset"(#loc2)) +#loc58 = loc("xindex"(#loc3)) +#loc59 = loc("xindex"(#loc4)) +#loc60 = loc("xindex"(#loc5)) +#loc61 = loc("xmask"(#loc6)) +#loc62 = loc("r0_base"(#loc7)) +#loc63 = loc("r0_base"(#loc8)) +#loc64 = loc("x0"(#loc9)) +#loc65 = loc("x1"(#loc10)) +#loc66 = loc("_tmp3"(#loc11)) +#loc67 = loc("_tmp3"(#loc12)) +#loc68 = loc("r0_index"(#loc13)) +#loc69 = loc("r0_mask"(#loc14)) +#loc70 = loc("tmp0"(#loc15)) +#loc71 = loc("tmp0"(#loc16)) +#loc72 = loc("tmp0"(#loc17)) +#loc73 = loc("tmp0"(#loc18)) +#loc74 = loc("tmp0"(#loc19)) +#loc75 = loc("tmp0"(#loc20)) +#loc76 = loc("tmp0"(#loc21)) +#loc77 = loc("tmp0"(#loc22)) +#loc78 = loc("tmp1"(#loc23)) +#loc79 = loc("tmp4"(#loc24)) +#loc80 = loc("_tmp3"(#loc25)) +#loc81 = loc("_tmp3"(#loc26)) +#loc82 = loc("tmp3"(#loc28)) +#loc83 = loc("tmp3"(#loc29)) +#loc84 = loc("tmp5"(#loc30)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..7b63991c81ee03dcabbec16981abaa59b48e2598 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttgir @@ -0,0 +1,147 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":18:0) +#loc1 = loc(unknown) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:25) +#loc40 = loc("in_ptr0"(#loc)) +#loc41 = loc("out_ptr1"(#loc)) +#loc42 = loc("ks0"(#loc)) +#loc43 = loc("ks1"(#loc)) +#loc44 = loc("xnumel"(#loc)) +#loc45 = loc("r0_numel"(#loc)) +#loc68 = loc("tmp3"(#loc26)) +#loc73 = loc(callsite(#loc1 at #loc68)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused__to_copy_clone_slice_sum_transpose_5(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 loc("ks1"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<8x16xi64, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c1_i64 = arith.constant 1 : i64 
loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<8x16xi32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc46) + %xoffset_1 = arith.muli %xoffset, %c8_i32 : i32 loc(#loc47) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc48) + %xindex_2 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi32, #blocked> loc(#loc48) + %xindex_3 = tt.splat %xoffset_1 : i32 -> tensor<8x1xi32, #blocked> loc(#loc49) + %xindex_4 = arith.addi %xindex_3, %xindex_2 : tensor<8x1xi32, #blocked> loc(#loc49) + %xmask = tt.splat %xnumel : i32 -> tensor<8x1xi32, #blocked> loc(#loc50) + %xmask_5 = arith.cmpi slt, %xindex_4, %xmask : tensor<8x1xi32, #blocked> loc(#loc50) + %r0_base = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc51) + %r0_base_6 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc51) + %x0 = arith.extsi %xindex_4 : tensor<8x1xi32, #blocked> to tensor<8x1xi64, #blocked> loc(#loc52) + %x0_7 = tt.splat %ks0 : i64 -> tensor<8x1xi64, #blocked> loc(#loc52) + %x0_8 = arith.remsi %x0, %x0_7 : tensor<8x1xi64, #blocked> loc(#loc52) + %x1 = arith.divsi %x0, %x0_7 : tensor<8x1xi64, #blocked> loc(#loc53) + %r0_mask = tt.splat %r0_numel : i32 -> tensor<1x16xi32, #blocked> loc(#loc54) + %tmp0 = tt.splat %ks0 : i64 -> tensor<1x16xi64, #blocked> loc(#loc55) + %tmp0_9 = tt.broadcast %x0_8 : tensor<8x1xi64, #blocked> -> tensor<8x16xi64, #blocked> loc(#loc56) + %tmp0_10 = arith.muli %ks0, %ks1 : i64 loc(#loc57) + %tmp0_11 = tt.splat %tmp0_10 : i64 -> tensor<8x1xi64, #blocked> loc(#loc58) + %tmp0_12 = arith.muli %tmp0_11, %x1 : tensor<8x1xi64, #blocked> loc(#loc58) + %tmp0_13 = tt.broadcast %tmp0_12 : tensor<8x1xi64, #blocked> -> 
tensor<8x16xi64, #blocked> loc(#loc59) + %tmp0_14 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr, #blocked> loc(#loc60) + %tmp0_15 = tt.broadcast %xmask_5 : tensor<8x1xi1, #blocked> -> tensor<8x16xi1, #blocked> loc(#loc61) + %_tmp3 = scf.for %_tmp3_17 = %c0_i32 to %r0_numel step %c16_i32 iter_args(%_tmp3_18 = %cst) -> (tensor<8x16xi64, #blocked>) : i32 { + %r0_index = tt.splat %_tmp3_17 : i32 -> tensor<1x16xi32, #blocked> loc(#loc63) + %r0_index_19 = arith.addi %r0_index, %r0_base_6 : tensor<1x16xi32, #blocked> loc(#loc63) + %r0_mask_20 = arith.cmpi slt, %r0_index_19, %r0_mask : tensor<1x16xi32, #blocked> loc(#loc54) + %tmp0_21 = arith.extsi %r0_index_19 : tensor<1x16xi32, #blocked> to tensor<1x16xi64, #blocked> loc(#loc55) + %tmp0_22 = arith.muli %tmp0, %tmp0_21 : tensor<1x16xi64, #blocked> loc(#loc55) + %tmp0_23 = tt.broadcast %tmp0_22 : tensor<1x16xi64, #blocked> -> tensor<8x16xi64, #blocked> loc(#loc56) + %tmp0_24 = arith.addi %tmp0_9, %tmp0_23 : tensor<8x16xi64, #blocked> loc(#loc56) + %tmp0_25 = arith.addi %tmp0_24, %tmp0_13 : tensor<8x16xi64, #blocked> loc(#loc59) + %tmp0_26 = tt.addptr %tmp0_14, %tmp0_25 : tensor<8x16x!tt.ptr, #blocked>, tensor<8x16xi64, #blocked> loc(#loc60) + %tmp0_27 = tt.broadcast %r0_mask_20 : tensor<1x16xi1, #blocked> -> tensor<8x16xi1, #blocked> loc(#loc61) + %tmp0_28 = arith.andi %tmp0_27, %tmp0_15 : tensor<8x16xi1, #blocked> loc(#loc61) + %tmp0_29 = tt.load %tmp0_26, %tmp0_28, %cst_0 evictionPolicy = evict_last : tensor<8x16x!tt.ptr, #blocked> loc(#loc64) + %tmp1 = arith.extsi %tmp0_29 : tensor<8x16xi32, #blocked> to tensor<8x16xi64, #blocked> loc(#loc65) + %tmp4 = arith.addi %_tmp3_18, %tmp1 : tensor<8x16xi64, #blocked> loc(#loc66) + %_tmp3_30 = arith.select %tmp0_28, %tmp4, %_tmp3_18 : tensor<8x16xi1, #blocked>, tensor<8x16xi64, #blocked> loc(#loc67) + scf.yield %_tmp3_30 : tensor<8x16xi64, #blocked> loc(#loc24) + } loc(#loc62) + %tmp3 = "tt.reduce"(%_tmp3) <{axis = 1 : i32}> ({ + ^bb0(%tmp3_17: i64 loc(callsite(#loc1 at 
#loc68)), %tmp3_18: i64 loc(callsite(#loc1 at #loc68))): + %tmp3_19 = arith.addi %tmp3_17, %tmp3_18 : i64 loc(#loc74) + tt.reduce.return %tmp3_19 : i64 loc(#loc72) + }) : (tensor<8x16xi64, #blocked>) -> tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc72) + %tmp3_16 = tt.expand_dims %tmp3 {axis = 1 : i32} : tensor<8xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi64, #blocked> loc(#loc69) + %tmp5 = arith.trunci %tmp3_16 : tensor<8x1xi64, #blocked> to tensor<8x1xi32, #blocked> loc(#loc70) + %0 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc30) + %1 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc31) + %2 = arith.extui %1 : i1 to i64 loc(#loc32) + %3 = arith.muli %ks0, %2 : i64 loc(#loc32) + %4 = arith.extui %0 : i1 to i64 loc(#loc71) + %5 = arith.addi %4, %3 : i64 loc(#loc33) + %6 = tt.splat %5 : i64 -> tensor<8x1xi64, #blocked> loc(#loc35) + %7 = arith.muli %x1, %6 : tensor<8x1xi64, #blocked> loc(#loc35) + %8 = arith.addi %x0_8, %7 : tensor<8x1xi64, #blocked> loc(#loc36) + %9 = tt.splat %out_ptr1 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> loc(#loc37) + %10 = tt.addptr %9, %8 : tensor<8x1x!tt.ptr, #blocked>, tensor<8x1xi64, #blocked> loc(#loc37) + tt.store %10, %tmp5, %xmask_5 : tensor<8x1x!tt.ptr, #blocked> loc(#loc38) + tt.return loc(#loc39) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":23:21) 
+#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":24:37) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":26:19) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":27:19) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":32:29) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:43) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:39) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:54) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:58) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:50) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:34) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:73) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":30:40) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":31:31) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:63) +#loc21 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":37:23) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":39:23) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:48) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:8) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:28) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":42:19) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:49) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:75) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:66) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:57) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:41) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:34) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:30) +#loc37 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:25) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:88) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:4) +#loc46 = loc("xoffset"(#loc2)) +#loc47 = loc("xoffset"(#loc3)) +#loc48 = loc("xindex"(#loc4)) +#loc49 = loc("xindex"(#loc5)) +#loc50 = loc("xmask"(#loc6)) +#loc51 = loc("r0_base"(#loc7)) +#loc52 = loc("x0"(#loc8)) +#loc53 = loc("x1"(#loc9)) +#loc54 = loc("r0_mask"(#loc10)) +#loc55 = loc("tmp0"(#loc11)) +#loc56 = loc("tmp0"(#loc12)) +#loc57 = loc("tmp0"(#loc13)) +#loc58 = loc("tmp0"(#loc14)) +#loc59 = loc("tmp0"(#loc15)) +#loc60 = loc("tmp0"(#loc16)) +#loc61 = loc("tmp0"(#loc17)) +#loc62 = loc("_tmp3"(#loc18)) +#loc63 = loc("r0_index"(#loc19)) +#loc64 = loc("tmp0"(#loc20)) +#loc65 = loc("tmp1"(#loc21)) +#loc66 = loc("tmp4"(#loc22)) +#loc67 = loc("_tmp3"(#loc23)) +#loc69 = loc("tmp3"(#loc28)) +#loc70 = loc("tmp5"(#loc29)) +#loc71 = loc(fused[#loc33, #loc34]) +#loc72 = loc(callsite(#loc25 at #loc68)) +#loc74 = loc(callsite(#loc27 at #loc72)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttir new file mode 100644 index 0000000000000000000000000000000000000000..fab66be53d1751361695a0240fa4c61784e8a688 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A/triton_red_fused__to_copy_clone_slice_sum_transpose_5.ttir @@ -0,0 +1,152 @@ +#loc = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":18:0) +#loc1 = loc(unknown) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:25) +#loc43 = loc("in_ptr0"(#loc)) +#loc44 = loc("out_ptr1"(#loc)) +#loc45 = loc("ks0"(#loc)) +#loc46 = loc("ks1"(#loc)) +#loc47 = loc("xnumel"(#loc)) +#loc48 = loc("r0_numel"(#loc)) +#loc74 = loc("tmp3"(#loc29)) +#loc79 = loc(callsite(#loc1 at #loc74)) +module { + tt.func public @triton_red_fused__to_copy_clone_slice_sum_transpose_5(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 loc("ks1"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %c1_i64 = arith.constant 1 : i64 loc(#loc1) + %cst = arith.constant dense<0> : tensor<8x16xi32> loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc2) + %c0_i32 = arith.constant 0 : i32 loc(#loc2) + %_tmp3 = arith.constant dense<0> : tensor<8x16xi64> loc(#loc49) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc50) + %xoffset_0 = arith.muli %xoffset, %c8_i32 : i32 loc(#loc51) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc52) + %xindex_1 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc53) + %xindex_2 = tt.splat %xoffset_0 : i32 -> tensor<8x1xi32> loc(#loc54) + %xindex_3 = arith.addi %xindex_2, %xindex_1 : tensor<8x1xi32> loc(#loc54) + %xmask = tt.splat %xnumel : i32 -> tensor<8x1xi32> loc(#loc55) + %xmask_4 = arith.cmpi slt, %xindex_3, %xmask : tensor<8x1xi32> loc(#loc55) + %r0_base = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc56) + %r0_base_5 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<16xi32> -> 
tensor<1x16xi32> loc(#loc57) + %x0 = arith.extsi %xindex_3 : tensor<8x1xi32> to tensor<8x1xi64> loc(#loc58) + %x0_6 = tt.splat %ks0 : i64 -> tensor<8x1xi64> loc(#loc58) + %x0_7 = arith.remsi %x0, %x0_6 : tensor<8x1xi64> loc(#loc58) + %x1 = arith.divsi %x0, %x0_6 : tensor<8x1xi64> loc(#loc59) + %_tmp3_8 = scf.for %r0_offset = %c0_i32 to %r0_numel step %c16_i32 iter_args(%_tmp3_10 = %_tmp3) -> (tensor<8x16xi64>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x16xi32> loc(#loc61) + %r0_index_11 = arith.addi %r0_index, %r0_base_5 : tensor<1x16xi32> loc(#loc61) + %r0_mask = tt.splat %r0_numel : i32 -> tensor<1x16xi32> loc(#loc62) + %r0_mask_12 = arith.cmpi slt, %r0_index_11, %r0_mask : tensor<1x16xi32> loc(#loc62) + %tmp0 = arith.extsi %r0_index_11 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc63) + %tmp0_13 = tt.splat %ks0 : i64 -> tensor<1x16xi64> loc(#loc63) + %tmp0_14 = arith.muli %tmp0_13, %tmp0 : tensor<1x16xi64> loc(#loc63) + %tmp0_15 = tt.broadcast %x0_7 : tensor<8x1xi64> -> tensor<8x16xi64> loc(#loc64) + %tmp0_16 = tt.broadcast %tmp0_14 : tensor<1x16xi64> -> tensor<8x16xi64> loc(#loc64) + %tmp0_17 = arith.addi %tmp0_15, %tmp0_16 : tensor<8x16xi64> loc(#loc64) + %tmp0_18 = arith.muli %ks0, %ks1 : i64 loc(#loc65) + %tmp0_19 = tt.splat %tmp0_18 : i64 -> tensor<8x1xi64> loc(#loc66) + %tmp0_20 = arith.muli %tmp0_19, %x1 : tensor<8x1xi64> loc(#loc66) + %tmp0_21 = tt.broadcast %tmp0_20 : tensor<8x1xi64> -> tensor<8x16xi64> loc(#loc67) + %tmp0_22 = arith.addi %tmp0_17, %tmp0_21 : tensor<8x16xi64> loc(#loc67) + %tmp0_23 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc68) + %tmp0_24 = tt.addptr %tmp0_23, %tmp0_22 : tensor<8x16x!tt.ptr>, tensor<8x16xi64> loc(#loc68) + %tmp0_25 = tt.broadcast %r0_mask_12 : tensor<1x16xi1> -> tensor<8x16xi1> loc(#loc69) + %tmp0_26 = tt.broadcast %xmask_4 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc69) + %tmp0_27 = arith.andi %tmp0_25, %tmp0_26 : tensor<8x16xi1> loc(#loc69) + %tmp0_28 = tt.load %tmp0_24, %tmp0_27, 
%cst evictionPolicy = evict_last : tensor<8x16x!tt.ptr> loc(#loc70) + %tmp1 = arith.extsi %tmp0_28 : tensor<8x16xi32> to tensor<8x16xi64> loc(#loc71) + %tmp4 = arith.addi %_tmp3_10, %tmp1 : tensor<8x16xi64> loc(#loc72) + %_tmp3_29 = arith.select %tmp0_27, %tmp4, %_tmp3_10 : tensor<8x16xi1>, tensor<8x16xi64> loc(#loc73) + scf.yield %_tmp3_29 : tensor<8x16xi64> loc(#loc27) + } loc(#loc60) + %tmp3 = "tt.reduce"(%_tmp3_8) <{axis = 1 : i32}> ({ + ^bb0(%tmp3_10: i64 loc(callsite(#loc1 at #loc74)), %tmp3_11: i64 loc(callsite(#loc1 at #loc74))): + %tmp3_12 = arith.addi %tmp3_10, %tmp3_11 : i64 loc(#loc80) + tt.reduce.return %tmp3_12 : i64 loc(#loc78) + }) : (tensor<8x16xi64>) -> tensor<8xi64> loc(#loc78) + %tmp3_9 = tt.expand_dims %tmp3 {axis = 1 : i32} : tensor<8xi64> -> tensor<8x1xi64> loc(#loc75) + %tmp5 = arith.trunci %tmp3_9 : tensor<8x1xi64> to tensor<8x1xi32> loc(#loc76) + %0 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc33) + %1 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc34) + %2 = arith.extui %1 : i1 to i64 loc(#loc35) + %3 = arith.muli %ks0, %2 : i64 loc(#loc35) + %4 = arith.extui %0 : i1 to i64 loc(#loc77) + %5 = arith.addi %4, %3 : i64 loc(#loc36) + %6 = tt.splat %5 : i64 -> tensor<8x1xi64> loc(#loc38) + %7 = arith.muli %x1, %6 : tensor<8x1xi64> loc(#loc38) + %8 = arith.addi %x0_7, %7 : tensor<8x1xi64> loc(#loc39) + %9 = tt.splat %out_ptr1 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc40) + %10 = tt.addptr %9, %8 : tensor<8x1x!tt.ptr>, tensor<8x1xi64> loc(#loc40) + tt.store %10, %tmp5, %xmask_4 : tensor<8x1x!tt.ptr> loc(#loc41) + tt.return loc(#loc42) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":30:40) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":28:43) +#loc4 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":21:33) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:36) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:44) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":22:23) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":23:21) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":24:27) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":24:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":26:19) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":27:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":31:31) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":32:29) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:43) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:39) +#loc18 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:54) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:58) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:50) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:34) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:73) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":36:63) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":37:23) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":39:23) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:48) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":40:8) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":41:28) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":42:19) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:49) +#loc34 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:75) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:66) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:57) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:41) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:34) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:30) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:25) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:88) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wc/cwc2jctz67s7n6fotl22gr44glijrcfxpeblvk32hornlllznx2m.py":43:4) +#loc49 = loc("_tmp3"(#loc3)) +#loc50 = loc("xoffset"(#loc4)) +#loc51 = loc("xoffset"(#loc5)) +#loc52 = loc("xindex"(#loc6)) +#loc53 = loc("xindex"(#loc7)) +#loc54 = loc("xindex"(#loc8)) +#loc55 = loc("xmask"(#loc9)) +#loc56 = loc("r0_base"(#loc10)) +#loc57 = loc("r0_base"(#loc11)) +#loc58 = loc("x0"(#loc12)) +#loc59 = loc("x1"(#loc13)) +#loc60 = loc("_tmp3"(#loc2)) +#loc61 = loc("r0_index"(#loc14)) +#loc62 = loc("r0_mask"(#loc15)) +#loc63 = loc("tmp0"(#loc16)) +#loc64 = loc("tmp0"(#loc17)) +#loc65 = loc("tmp0"(#loc18)) +#loc66 = loc("tmp0"(#loc19)) +#loc67 = loc("tmp0"(#loc20)) +#loc68 = loc("tmp0"(#loc21)) +#loc69 = loc("tmp0"(#loc22)) +#loc70 = loc("tmp0"(#loc23)) +#loc71 = loc("tmp1"(#loc24)) +#loc72 = loc("tmp4"(#loc25)) +#loc73 = loc("_tmp3"(#loc26)) +#loc75 = 
loc("tmp3"(#loc31)) +#loc76 = loc("tmp5"(#loc32)) +#loc77 = loc(fused[#loc36, #loc37]) +#loc78 = loc(callsite(#loc28 at #loc74)) +#loc80 = loc(callsite(#loc30 at #loc78)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/__grp__triton_red_fused_eq_mul_squeeze_sum_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/__grp__triton_red_fused_eq_mul_squeeze_sum_2.json new file mode 100644 index 0000000000000000000000000000000000000000..8053573a7489d202584b610dba93c13ca879a092 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/__grp__triton_red_fused_eq_mul_squeeze_sum_2.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_eq_mul_squeeze_sum_2.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.source", "triton_red_fused_eq_mul_squeeze_sum_2.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttir", "triton_red_fused_eq_mul_squeeze_sum_2.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttgir", "triton_red_fused_eq_mul_squeeze_sum_2.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.llir", "triton_red_fused_eq_mul_squeeze_sum_2.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ptx", "triton_red_fused_eq_mul_squeeze_sum_2.cubin": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.cubin", "triton_red_fused_eq_mul_squeeze_sum_2.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.cubin new file mode 100644 index 0000000000000000000000000000000000000000..4e8bab62ca9644a78a245c24206ce45fe362ed94 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.json new file mode 100644 index 0000000000000000000000000000000000000000..487b9f7dd57ba85fc3c69acd3b2af000084d49a0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.json @@ -0,0 +1 @@ +{"hash": "57807a2af8d8cda7dd5c426812f78ddd6d3e1b3ed08ac308d0a3e735392fa14e", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 16, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": 
["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 128, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_eq_mul_squeeze_sum_2"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.llir new file mode 100644 index 0000000000000000000000000000000000000000..46d848454735265fe1304f407049ff4558287456 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.llir @@ -0,0 +1,312 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_eq_mul_squeeze_sum_2(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, i64 %4, i32 %5, i32 %6, ptr addrspace(1) readnone captures(none) %7, ptr addrspace(1) readnone captures(none) %8) local_unnamed_addr #0 !dbg !4 { + %10 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %11 = icmp samesign ult i32 %10, 2, 
!dbg !8 + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %13 = and i32 %12, 511, !dbg !9 + %14 = shl i64 %4, 2, !dbg !10 + %15 = zext nneg i32 %10 to i64, !dbg !11 + %16 = mul i64 %14, %15, !dbg !11 + %17 = icmp sgt i32 %6, 0, !dbg !12 + br i1 %17, label %.lr.ph.preheader, label %._crit_edge, !dbg !12 + +.lr.ph.preheader: ; preds = %9 + %18 = insertelement <4 x i32> poison, i32 %6, i64 0 + %19 = shufflevector <4 x i32> %18, <4 x i32> poison, <4 x i32> zeroinitializer + %20 = insertelement <2 x i64> , i64 %4, i64 0 + %21 = insertelement <4 x i1> poison, i1 %11, i64 0 + %22 = shufflevector <4 x i1> %21, <4 x i1> poison, <4 x i32> zeroinitializer + br label %.lr.ph, !dbg !12 + +.lr.ph: ; preds = %.lr.ph.preheader, %.lr.ph + %23 = phi i32 [ %132, %.lr.ph ], [ 0, %.lr.ph.preheader ] + %24 = phi <4 x i64> [ %131, %.lr.ph ], [ zeroinitializer, %.lr.ph.preheader ] + %25 = or disjoint i32 %12, %23, !dbg !13 + %26 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !14 + %27 = or disjoint i32 %23, %13, !dbg !13 + %28 = or i32 %25, 512, !dbg !13 + %29 = or disjoint i32 %27, 1024, !dbg !13 + %30 = or i32 %25, 1536, !dbg !13 + %31 = insertelement <4 x i32> poison, i32 %27, i64 0, !dbg !15 + %32 = insertelement <4 x i32> %31, i32 %28, i64 1, !dbg !15 + %33 = insertelement <4 x i32> %32, i32 %29, i64 2, !dbg !15 + %34 = insertelement <4 x i32> %33, i32 %30, i64 3, !dbg !15 + %35 = icmp slt <4 x i32> %34, %19, !dbg !15 + %36 = sext i32 %27 to i64, !dbg !16 + %37 = sext i32 %28 to i64, !dbg !16 + %38 = sext i32 %29 to i64, !dbg !16 + %39 = sext i32 %30 to i64, !dbg !16 + %40 = add i64 %16, %36, !dbg !16 + %41 = add i64 %16, %37, !dbg !16 + %42 = add i64 %16, %38, !dbg !16 + %43 = add i64 %16, %39, !dbg !16 + %44 = sdiv i64 %40, %4, !dbg !17 + %45 = sdiv i64 %41, %4, !dbg !17 + %46 = sdiv i64 %42, %4, !dbg !17 + %47 = sdiv i64 %43, %4, !dbg !17 + %48 = insertelement <2 x i64> poison, i64 %36, 
i64 0, !dbg !18 + %49 = insertelement <2 x i64> %48, i64 %44, i64 1, !dbg !18 + %50 = srem <2 x i64> %49, %20, !dbg !18 + %51 = extractelement <2 x i64> %50, i64 1, !dbg !19 + %52 = mul i64 %51, %4, !dbg !19 + %53 = insertelement <2 x i64> poison, i64 %37, i64 0, !dbg !18 + %54 = insertelement <2 x i64> %53, i64 %45, i64 1, !dbg !18 + %55 = srem <2 x i64> %54, %20, !dbg !18 + %56 = extractelement <2 x i64> %55, i64 1, !dbg !19 + %57 = mul i64 %56, %4, !dbg !19 + %58 = insertelement <2 x i64> poison, i64 %38, i64 0, !dbg !18 + %59 = insertelement <2 x i64> %58, i64 %46, i64 1, !dbg !18 + %60 = srem <2 x i64> %59, %20, !dbg !18 + %61 = extractelement <2 x i64> %60, i64 1, !dbg !19 + %62 = mul i64 %61, %4, !dbg !19 + %63 = insertelement <2 x i64> poison, i64 %39, i64 0, !dbg !18 + %64 = insertelement <2 x i64> %63, i64 %47, i64 1, !dbg !18 + %65 = srem <2 x i64> %64, %20, !dbg !18 + %66 = extractelement <2 x i64> %65, i64 1, !dbg !19 + %67 = mul i64 %66, %4, !dbg !19 + %68 = extractelement <2 x i64> %50, i64 0, !dbg !20 + %69 = add i64 %68, %52, !dbg !20 + %70 = extractelement <2 x i64> %55, i64 0, !dbg !20 + %71 = add i64 %70, %57, !dbg !20 + %72 = extractelement <2 x i64> %60, i64 0, !dbg !20 + %73 = add i64 %72, %62, !dbg !20 + %74 = extractelement <2 x i64> %65, i64 0, !dbg !20 + %75 = add i64 %74, %67, !dbg !20 + %76 = getelementptr i64, ptr addrspace(1) %0, i64 %69, !dbg !21 + %77 = getelementptr i64, ptr addrspace(1) %0, i64 %71, !dbg !21 + %78 = getelementptr i64, ptr addrspace(1) %0, i64 %73, !dbg !21 + %79 = getelementptr i64, ptr addrspace(1) %0, i64 %75, !dbg !21 + %80 = and <4 x i1> %22, %35, !dbg !22 + %81 = extractelement <4 x i1> %80, i64 0, !dbg !23 + %82 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %76, i64 %26, i1 %81) #5, !dbg !14 + %83 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 
1.0;", "=l"() #5, !dbg !14 + %84 = extractelement <4 x i1> %80, i64 1, !dbg !23 + %85 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %77, i64 %83, i1 %84) #5, !dbg !14 + %86 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !14 + %87 = extractelement <4 x i1> %80, i64 2, !dbg !23 + %88 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %78, i64 %86, i1 %87) #5, !dbg !14 + %89 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !14 + %90 = extractelement <4 x i1> %80, i64 3, !dbg !23 + %91 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %79, i64 %89, i1 %90) #5, !dbg !14 + %92 = getelementptr i64, ptr addrspace(1) %1, i64 %69, !dbg !24 + %93 = getelementptr i64, ptr addrspace(1) %1, i64 %71, !dbg !24 + %94 = getelementptr i64, ptr addrspace(1) %1, i64 %73, !dbg !24 + %95 = getelementptr i64, ptr addrspace(1) %1, i64 %75, !dbg !24 + %96 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !25 + %97 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %92, i64 %96, i1 %81) #5, !dbg !25 + %98 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !25 + %99 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %93, i64 %98, i1 %84) #5, !dbg !25 + %100 = tail call i64 asm 
sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !25 + %101 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %94, i64 %100, i1 %87) #5, !dbg !25 + %102 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !25 + %103 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %95, i64 %102, i1 %90) #5, !dbg !25 + %104 = getelementptr i64, ptr addrspace(1) %2, i64 %69, !dbg !26 + %105 = getelementptr i64, ptr addrspace(1) %2, i64 %71, !dbg !26 + %106 = getelementptr i64, ptr addrspace(1) %2, i64 %73, !dbg !26 + %107 = getelementptr i64, ptr addrspace(1) %2, i64 %75, !dbg !26 + %108 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !23 + %109 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %104, i64 %108, i1 %81) #5, !dbg !23 + %110 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !23 + %111 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %105, i64 %110, i1 %84) #5, !dbg !23 + %112 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !23 + %113 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %106, i64 %112, i1 %87) #5, !dbg !23 + %114 = tail call i64 asm sideeffect "mov.u64 $0, 
0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #5, !dbg !23 + %115 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b64 { $0 }, [ $1 + 0 ], $2;", "=l,l,l,b"(ptr addrspace(1) %107, i64 %114, i1 %90) #5, !dbg !23 + %116 = insertelement <4 x i64> poison, i64 %82, i64 0, !dbg !27 + %117 = insertelement <4 x i64> %116, i64 %85, i64 1, !dbg !27 + %118 = insertelement <4 x i64> %117, i64 %88, i64 2, !dbg !27 + %119 = insertelement <4 x i64> %118, i64 %91, i64 3, !dbg !27 + %120 = insertelement <4 x i64> poison, i64 %97, i64 0, !dbg !27 + %121 = insertelement <4 x i64> %120, i64 %99, i64 1, !dbg !27 + %122 = insertelement <4 x i64> %121, i64 %101, i64 2, !dbg !27 + %123 = insertelement <4 x i64> %122, i64 %103, i64 3, !dbg !27 + %124 = icmp eq <4 x i64> %119, %123, !dbg !27 + %125 = select <4 x i1> %80, <4 x i1> %124, <4 x i1> zeroinitializer, !dbg !28 + %126 = insertelement <4 x i64> poison, i64 %109, i64 0, !dbg !28 + %127 = insertelement <4 x i64> %126, i64 %111, i64 1, !dbg !28 + %128 = insertelement <4 x i64> %127, i64 %113, i64 2, !dbg !28 + %129 = insertelement <4 x i64> %128, i64 %115, i64 3, !dbg !28 + %130 = select <4 x i1> %125, <4 x i64> %129, <4 x i64> zeroinitializer, !dbg !28 + %131 = add <4 x i64> %130, %24, !dbg !28 + %132 = add i32 %23, 2048, !dbg !12 + %133 = icmp slt i32 %132, %6, !dbg !12 + br i1 %133, label %.lr.ph, label %._crit_edge.loopexit, !dbg !12 + +._crit_edge.loopexit: ; preds = %.lr.ph + %134 = tail call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %131), !dbg !9 + br label %._crit_edge, !dbg !9 + +._crit_edge: ; preds = %._crit_edge.loopexit, %9 + %135 = phi i64 [ 0, %9 ], [ %134, %._crit_edge.loopexit ], !dbg !29 + %136 = and i32 %12, 31, !dbg !9 + %137 = lshr i32 %12, 5, !dbg !9 + %extelt.offset = lshr i64 %135, 32, !dbg !33 + %138 = trunc nuw i64 %extelt.offset to i32, !dbg !33 + %139 = trunc i64 %135 to i32, !dbg !33 + %140 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %139, i32 16, i32 31), !dbg !33 + %141 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %138, i32 16, i32 31), !dbg !33 + %142 = insertelement <2 x i32> poison, i32 %140, i64 0, !dbg !33 + %143 = insertelement <2 x i32> %142, i32 %141, i64 1, !dbg !33 + %144 = bitcast <2 x i32> %143 to i64, !dbg !33 + %145 = add i64 %135, %144, !dbg !29 + %extelt.offset2 = lshr i64 %145, 32, !dbg !33 + %146 = trunc nuw i64 %extelt.offset2 to i32, !dbg !33 + %147 = trunc i64 %145 to i32, !dbg !33 + %148 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %147, i32 8, i32 31), !dbg !33 + %149 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %146, i32 8, i32 31), !dbg !33 + %150 = insertelement <2 x i32> poison, i32 %148, i64 0, !dbg !33 + %151 = insertelement <2 x i32> %150, i32 %149, i64 1, !dbg !33 + %152 = bitcast <2 x i32> %151 to i64, !dbg !33 + %153 = add i64 %145, %152, !dbg !29 + %extelt.offset3 = lshr i64 %153, 32, !dbg !33 + %154 = trunc nuw i64 %extelt.offset3 to i32, !dbg !33 + %155 = trunc i64 %153 to i32, !dbg !33 + %156 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %155, i32 4, i32 31), !dbg !33 + %157 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %154, i32 4, i32 31), !dbg !33 + %158 = insertelement <2 x i32> poison, i32 %156, i64 0, !dbg !33 + %159 = insertelement <2 x i32> %158, i32 %157, i64 1, !dbg !33 + %160 = bitcast <2 x i32> %159 to i64, !dbg !33 + %161 = add i64 %153, %160, !dbg !29 + %extelt.offset4 = lshr i64 %161, 32, !dbg !33 + %162 = trunc nuw i64 %extelt.offset4 to i32, !dbg !33 + %163 = trunc i64 %161 to i32, !dbg !33 + %164 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %163, i32 2, i32 31), !dbg !33 + %165 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %162, i32 2, i32 31), !dbg !33 + %166 = insertelement <2 x i32> poison, i32 %164, i64 0, !dbg !33 + %167 = insertelement <2 x i32> %166, i32 %165, i64 1, !dbg !33 + %168 = bitcast <2 x i32> 
%167 to i64, !dbg !33 + %169 = add i64 %161, %168, !dbg !29 + %extelt.offset5 = lshr i64 %169, 32, !dbg !33 + %170 = trunc nuw i64 %extelt.offset5 to i32, !dbg !33 + %171 = trunc i64 %169 to i32, !dbg !33 + %172 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %171, i32 1, i32 31), !dbg !33 + %173 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %170, i32 1, i32 31), !dbg !33 + %174 = insertelement <2 x i32> poison, i32 %172, i64 0, !dbg !33 + %175 = insertelement <2 x i32> %174, i32 %173, i64 1, !dbg !33 + %176 = bitcast <2 x i32> %175 to i64, !dbg !33 + %177 = add i64 %169, %176, !dbg !29 + %178 = and i32 %137, 15, !dbg !33 + %179 = icmp eq i32 %136, 0, !dbg !33 + %180 = getelementptr i64, ptr addrspace(3) @global_smem, i32 %178, !dbg !33 + %181 = insertelement <1 x i64> poison, i64 %177, i64 0, !dbg !33 + tail call void asm sideeffect "@$2 st.shared.b64 [ $0 + 0 ], $1;", "r,l,b"(ptr addrspace(3) %180, <1 x i64> %181, i1 %179) #5, !dbg !33 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !33 + %182 = icmp samesign ult i32 %12, 16, !dbg !33 + %183 = getelementptr i64, ptr addrspace(3) @global_smem, i32 %12, !dbg !33 + %184 = tail call i64 asm sideeffect "@$2 ld.shared.b64 $0, [ $1 + 0 ];", "=l,r,b"(ptr addrspace(3) %183, i1 %182) #5, !dbg !33 + %extelt.offset6 = lshr i64 %184, 32, !dbg !33 + %185 = trunc nuw i64 %extelt.offset6 to i32, !dbg !33 + %186 = trunc i64 %184 to i32, !dbg !33 + %187 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %186, i32 8, i32 31), !dbg !33 + %188 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %185, i32 8, i32 31), !dbg !33 + %189 = insertelement <2 x i32> poison, i32 %187, i64 0, !dbg !33 + %190 = insertelement <2 x i32> %189, i32 %188, i64 1, !dbg !33 + %191 = bitcast <2 x i32> %190 to i64, !dbg !33 + %192 = add i64 %184, %191, !dbg !29 + %extelt.offset7 = lshr i64 %192, 32, !dbg !33 + %193 = trunc nuw i64 %extelt.offset7 to i32, !dbg !33 + %194 = trunc i64 %192 to i32, !dbg 
!33 + %195 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %194, i32 4, i32 31), !dbg !33 + %196 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %193, i32 4, i32 31), !dbg !33 + %197 = insertelement <2 x i32> poison, i32 %195, i64 0, !dbg !33 + %198 = insertelement <2 x i32> %197, i32 %196, i64 1, !dbg !33 + %199 = bitcast <2 x i32> %198 to i64, !dbg !33 + %200 = add i64 %192, %199, !dbg !29 + %extelt.offset8 = lshr i64 %200, 32, !dbg !33 + %201 = trunc nuw i64 %extelt.offset8 to i32, !dbg !33 + %202 = trunc i64 %200 to i32, !dbg !33 + %203 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %202, i32 2, i32 31), !dbg !33 + %204 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %201, i32 2, i32 31), !dbg !33 + %205 = insertelement <2 x i32> poison, i32 %203, i64 0, !dbg !33 + %206 = insertelement <2 x i32> %205, i32 %204, i64 1, !dbg !33 + %207 = bitcast <2 x i32> %206 to i64, !dbg !33 + %208 = add i64 %200, %207, !dbg !29 + %extelt.offset9 = lshr i64 %208, 32, !dbg !33 + %209 = trunc nuw i64 %extelt.offset9 to i32, !dbg !33 + %210 = trunc i64 %208 to i32, !dbg !33 + %211 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %210, i32 1, i32 31), !dbg !33 + %212 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %209, i32 1, i32 31), !dbg !33 + %213 = insertelement <2 x i32> poison, i32 %211, i64 0, !dbg !33 + %214 = insertelement <2 x i32> %213, i32 %212, i64 1, !dbg !33 + %215 = bitcast <2 x i32> %214 to i64, !dbg !33 + %216 = add i64 %208, %215, !dbg !29 + %217 = icmp eq i32 %12, 0, !dbg !33 + %218 = insertelement <1 x i64> poison, i64 %216, i64 0, !dbg !33 + tail call void asm sideeffect "@$2 st.shared.b64 [ $0 + 0 ], $1;", "r,l,b"(ptr addrspace(3) %183, <1 x i64> %218, i1 %217) #5, !dbg !33 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !33 + %219 = load i64, ptr addrspace(3) @global_smem, align 16, !dbg !33 + %220 = getelementptr i64, ptr addrspace(1) %3, i64 %15, !dbg !34 + %221 = icmp eq 
i32 %13, 0, !dbg !35 + %222 = and i1 %11, %221, !dbg !35 + tail call void asm sideeffect "@$2 st.global.b64 [ $1 + 0 ], { $0 };", "l,l,b"(i64 %219, ptr addrspace(1) %220, i1 %222) #5, !dbg !35 + ret void, !dbg !36 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i64 @llvm.vector.reduce.add.v4i64(<4 x i64>) #4 + +attributes #0 = { nounwind "nvvm.reqntid"="512" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #5 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_eq_mul_squeeze_sum_2", linkageName: 
"triton_red_fused_eq_mul_squeeze_sum_2", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 22, column: 28, scope: !4) +!8 = !DILocation(line: 24, column: 21, scope: !4) +!9 = !DILocation(line: 25, column: 37, scope: !4) +!10 = !DILocation(line: 35, column: 51, scope: !4) +!11 = !DILocation(line: 35, column: 55, scope: !4) +!12 = !DILocation(line: 29, column: 40, scope: !4) +!13 = !DILocation(line: 30, column: 31, scope: !4) +!14 = !DILocation(line: 35, column: 92, scope: !4) +!15 = !DILocation(line: 31, column: 29, scope: !4) +!16 = !DILocation(line: 35, column: 49, scope: !4) +!17 = !DILocation(line: 35, column: 62, scope: !4) +!18 = !DILocation(line: 35, column: 84, scope: !4) +!19 = !DILocation(line: 35, column: 40, scope: !4) +!20 = !DILocation(line: 35, column: 77, scope: !4) +!21 = !DILocation(line: 35, column: 34, scope: !4) +!22 = !DILocation(line: 35, column: 102, scope: !4) +!23 = !DILocation(line: 37, column: 92, scope: !4) +!24 = !DILocation(line: 36, column: 34, scope: !4) +!25 = !DILocation(line: 36, column: 92, scope: !4) +!26 = !DILocation(line: 37, column: 34, scope: !4) +!27 = !DILocation(line: 38, column: 23, scope: !4) +!28 = !DILocation(line: 43, column: 48, scope: !4) +!29 = !DILocation(line: 261, column: 15, scope: !30, inlinedAt: !32) +!30 = distinct !DILexicalBlockFile(scope: !4, file: !31, discriminator: 0) +!31 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!32 = !DILocation(line: 44, column: 25, scope: !4) +!33 = !DILocation(line: 291, column: 36, scope: !30, inlinedAt: !32) +!34 = !DILocation(line: 45, column: 25, scope: !4) +!35 = !DILocation(line: 45, column: 36, scope: !4) +!36 = !DILocation(line: 45, column: 4, scope: !4) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ptx new file mode 100644 index 0000000000000000000000000000000000000000..87489ee54e8a2ad10373eaa6118c35a4445a6914 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ptx @@ -0,0 +1,703 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_eq_mul_squeeze_sum_2 // -- Begin function triton_red_fused_eq_mul_squeeze_sum_2 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_eq_mul_squeeze_sum_2 +.visible .entry triton_red_fused_eq_mul_squeeze_sum_2( + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_2, + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_3, + .param .u64 triton_red_fused_eq_mul_squeeze_sum_2_param_4, + .param .u32 triton_red_fused_eq_mul_squeeze_sum_2_param_5, + .param .u32 triton_red_fused_eq_mul_squeeze_sum_2_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_7, + .param .u64 .ptr .global .align 1 triton_red_fused_eq_mul_squeeze_sum_2_param_8 +) +.reqntid 512 +{ + .reg .pred %p<37>; + .reg .b32 %r<77>; + .reg .b64 %rd<187>; + .loc 1 18 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:18:0 + +// %bb.0: + ld.param.b32 %r13, [triton_red_fused_eq_mul_squeeze_sum_2_param_6]; + ld.param.b64 %rd39, 
[triton_red_fused_eq_mul_squeeze_sum_2_param_3]; +$L__tmp0: + .loc 1 22 28 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:22:28 + mov.u32 %r14, %ctaid.x; + .loc 1 25 37 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:25:37 + mov.u32 %r1, %tid.x; + and.b32 %r2, %r1, 511; + .loc 1 35 55 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:55 + cvt.u64.u32 %rd1, %r14; + .loc 1 29 40 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:29:40 + setp.lt.s32 %p5, %r13, 1; + mov.b64 %rd186, 0; + cvt.u32.u64 %r75, %rd1; + @%p5 bra $L__BB0_16; +// %bb.1: // %.lr.ph.preheader + .loc 1 0 40 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:0:40 + ld.param.b64 %rd40, [triton_red_fused_eq_mul_squeeze_sum_2_param_4]; + ld.param.b64 %rd38, [triton_red_fused_eq_mul_squeeze_sum_2_param_2]; + ld.param.b64 %rd37, [triton_red_fused_eq_mul_squeeze_sum_2_param_1]; + ld.param.b64 %rd36, [triton_red_fused_eq_mul_squeeze_sum_2_param_0]; + mul.lo.s64 %rd42, %rd40, %rd1; + shl.b64 %rd2, %rd42, 2; + .loc 1 24 21 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:24:21 + setp.lt.u32 %p1, %r75, 2; + mov.b64 %rd4, 8; + mov.b64 %rd178, 0; + mov.b32 %r76, 0; + mov.b64 %rd179, %rd178; + mov.b64 %rd180, %rd178; + mov.b64 %rd181, %rd178; + bra.uni $L__BB0_2; +$L__BB0_13: // in Loop: Header=BB0_2 Depth=1 + .loc 1 35 62 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:62 + div.s64 %rd185, %rd17, %rd40; +$L__BB0_14: // in Loop: Header=BB0_2 Depth=1 + .loc 1 31 29 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:31:29 + setp.lt.s32 %p22, %r8, %r13; + setp.lt.s32 %p23, %r9, %r13; + setp.lt.s32 %p24, %r10, %r13; + setp.lt.s32 %p25, %r11, %r13; + .loc 1 35 84 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:84 + rem.s64 %rd103, %rd10, %rd40; + rem.s64 %rd104, %rd182, %rd4; + rem.s64 %rd105, %rd11, %rd40; + rem.s64 %rd106, %rd183, %rd4; + rem.s64 %rd107, %rd12, %rd40; + rem.s64 %rd108, %rd184, %rd4; + rem.s64 %rd109, 
%rd13, %rd40; + rem.s64 %rd110, %rd185, %rd4; + .loc 1 35 77 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:77 + mad.lo.s64 %rd111, %rd104, %rd40, %rd103; + mad.lo.s64 %rd112, %rd106, %rd40, %rd105; + mad.lo.s64 %rd113, %rd108, %rd40, %rd107; + mad.lo.s64 %rd114, %rd110, %rd40, %rd109; + .loc 1 35 34 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:34 + shl.b64 %rd115, %rd111, 3; + add.s64 %rd57, %rd36, %rd115; + shl.b64 %rd116, %rd112, 3; + add.s64 %rd61, %rd36, %rd116; + shl.b64 %rd117, %rd113, 3; + add.s64 %rd65, %rd36, %rd117; + shl.b64 %rd118, %rd114, 3; + add.s64 %rd69, %rd36, %rd118; + .loc 1 35 102 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:102 + and.pred %p17, %p1, %p25; + and.pred %p16, %p1, %p24; + and.pred %p15, %p1, %p23; + and.pred %p14, %p1, %p22; + .loc 1 35 92 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:92 + // begin inline asm + mov.u64 %rd56, 0x0; + @%p14 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd56 }, [ %rd57 + 0 ], %rd47; + // end inline asm + // begin inline asm + mov.u64 %rd59, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd59, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd60, 0x0; + @%p15 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd60 }, [ %rd61 + 0 ], %rd59; + // end inline asm + // begin inline asm + mov.u64 %rd63, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd63, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd64, 0x0; + @%p16 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd64 }, [ %rd65 + 0 ], %rd63; + // end inline asm + // begin inline asm + mov.u64 %rd67, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd67, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd68, 0x0; + @%p17 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd68 }, [ %rd69 + 0 ], %rd67; + // end inline asm + .loc 1 36 34 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:36:34 + add.s64 %rd73, %rd37, %rd115; + add.s64 %rd77, %rd37, 
%rd116; + add.s64 %rd81, %rd37, %rd117; + add.s64 %rd85, %rd37, %rd118; + .loc 1 36 92 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:36:92 + // begin inline asm + mov.u64 %rd71, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd71, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd72, 0x0; + @%p14 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd72 }, [ %rd73 + 0 ], %rd71; + // end inline asm + // begin inline asm + mov.u64 %rd75, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd75, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd76, 0x0; + @%p15 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd76 }, [ %rd77 + 0 ], %rd75; + // end inline asm + // begin inline asm + mov.u64 %rd79, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd79, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd80, 0x0; + @%p16 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd80 }, [ %rd81 + 0 ], %rd79; + // end inline asm + // begin inline asm + mov.u64 %rd83, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd83, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd84, 0x0; + @%p17 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd84 }, [ %rd85 + 0 ], %rd83; + // end inline asm + .loc 1 37 34 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:37:34 + add.s64 %rd89, %rd38, %rd115; + add.s64 %rd93, %rd38, %rd116; + add.s64 %rd97, %rd38, %rd117; + add.s64 %rd101, %rd38, %rd118; + .loc 1 37 92 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:37:92 + // begin inline asm + mov.u64 %rd87, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd87, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd88, 0x0; + @%p14 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd88 }, [ %rd89 + 0 ], %rd87; + // end inline asm + // begin inline asm + mov.u64 %rd91, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd91, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd92, 0x0; + @%p15 
ld.global.L1::evict_last.L2::cache_hint.b64 { %rd92 }, [ %rd93 + 0 ], %rd91; + // end inline asm + // begin inline asm + mov.u64 %rd95, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd95, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd96, 0x0; + @%p16 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd96 }, [ %rd97 + 0 ], %rd95; + // end inline asm + // begin inline asm + mov.u64 %rd99, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd99, 1.0; + // end inline asm + // begin inline asm + mov.u64 %rd100, 0x0; + @%p17 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd100 }, [ %rd101 + 0 ], %rd99; + // end inline asm + .loc 1 38 23 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:38:23 + setp.eq.b64 %p26, %rd56, %rd72; + setp.eq.b64 %p27, %rd60, %rd76; + setp.eq.b64 %p28, %rd64, %rd80; + setp.eq.b64 %p29, %rd68, %rd84; + .loc 1 43 48 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:43:48 + selp.b64 %rd119, %rd100, 0, %p29; + selp.b64 %rd120, %rd119, 0, %p17; + selp.b64 %rd121, %rd96, 0, %p28; + selp.b64 %rd122, %rd121, 0, %p16; + selp.b64 %rd123, %rd92, 0, %p27; + selp.b64 %rd124, %rd123, 0, %p15; + selp.b64 %rd125, %rd88, 0, %p26; + selp.b64 %rd126, %rd125, 0, %p14; + add.s64 %rd178, %rd126, %rd178; + add.s64 %rd179, %rd124, %rd179; + add.s64 %rd180, %rd122, %rd180; + add.s64 %rd181, %rd120, %rd181; + .loc 1 29 40 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:29:40 + add.s32 %r76, %r76, 2048; + setp.lt.s32 %p30, %r76, %r13; + @%p30 bra $L__BB0_2; + bra.uni $L__BB0_15; +$L__BB0_2: // %.lr.ph + // =>This Inner Loop Header: Depth=1 + .loc 1 35 92 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:92 + add.s32 %r17, %r1, %r76; + // begin inline asm + mov.u64 %rd47, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd47, 1.0; + // end inline asm + .loc 1 30 31 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:30:31 + add.s32 %r8, %r2, %r76; + or.b32 %r9, %r17, 512; + .loc 1 35 49 // 
cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:49 + cvt.s64.s32 %rd10, %r8; + cvt.s64.s32 %rd11, %r9; + add.s64 %rd14, %rd2, %rd10; + add.s64 %rd15, %rd2, %rd11; + .loc 1 35 62 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:62 + or.b64 %rd48, %rd14, %rd40; + and.b64 %rd49, %rd48, -4294967296; + setp.ne.b64 %p6, %rd49, 0; + @%p6 bra $L__BB0_4; + bra.uni $L__BB0_3; +$L__BB0_4: // in Loop: Header=BB0_2 Depth=1 + div.s64 %rd182, %rd14, %rd40; + bra.uni $L__BB0_5; +$L__BB0_3: // in Loop: Header=BB0_2 Depth=1 + cvt.u32.u64 %r18, %rd40; + cvt.u32.u64 %r19, %rd14; + div.u32 %r20, %r19, %r18; + cvt.u64.u32 %rd182, %r20; +$L__BB0_5: // in Loop: Header=BB0_2 Depth=1 + .loc 1 0 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:0 + add.s32 %r10, %r8, 1024; + .loc 1 35 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35 + cvt.s64.s32 %rd12, %r10; + add.s64 %rd16, %rd2, %rd12; + .loc 1 35 62 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:62 + or.b64 %rd50, %rd15, %rd40; + and.b64 %rd51, %rd50, -4294967296; + setp.ne.b64 %p7, %rd51, 0; + @%p7 bra $L__BB0_7; + bra.uni $L__BB0_6; +$L__BB0_7: // in Loop: Header=BB0_2 Depth=1 + div.s64 %rd183, %rd15, %rd40; + bra.uni $L__BB0_8; +$L__BB0_6: // in Loop: Header=BB0_2 Depth=1 + cvt.u32.u64 %r21, %rd40; + cvt.u32.u64 %r22, %rd15; + div.u32 %r23, %r22, %r21; + cvt.u64.u32 %rd183, %r23; +$L__BB0_8: // in Loop: Header=BB0_2 Depth=1 + .loc 1 0 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:0 + or.b32 %r11, %r17, 1536; + .loc 1 35 0 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35 + cvt.s64.s32 %rd13, %r11; + add.s64 %rd17, %rd2, %rd13; + .loc 1 35 62 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:35:62 + or.b64 %rd52, %rd16, %rd40; + and.b64 %rd53, %rd52, -4294967296; + setp.ne.b64 %p8, %rd53, 0; + @%p8 bra $L__BB0_10; + bra.uni $L__BB0_9; +$L__BB0_10: // in Loop: Header=BB0_2 Depth=1 + div.s64 %rd184, %rd16, %rd40; + bra.uni $L__BB0_11; 
+$L__BB0_9: // in Loop: Header=BB0_2 Depth=1 + cvt.u32.u64 %r24, %rd40; + cvt.u32.u64 %r25, %rd16; + div.u32 %r26, %r25, %r24; + cvt.u64.u32 %rd184, %r26; +$L__BB0_11: // in Loop: Header=BB0_2 Depth=1 + or.b64 %rd54, %rd17, %rd40; + and.b64 %rd55, %rd54, -4294967296; + setp.ne.b64 %p9, %rd55, 0; + @%p9 bra $L__BB0_13; +// %bb.12: // in Loop: Header=BB0_2 Depth=1 + cvt.u32.u64 %r27, %rd40; + cvt.u32.u64 %r28, %rd17; + div.u32 %r29, %r28, %r27; + cvt.u64.u32 %rd185, %r29; + bra.uni $L__BB0_14; +$L__BB0_15: // %._crit_edge.loopexit + .loc 1 25 37 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:25:37 + add.s64 %rd127, %rd178, %rd180; + add.s64 %rd128, %rd179, %rd181; + add.s64 %rd186, %rd127, %rd128; +$L__BB0_16: // %._crit_edge + .loc 1 24 21 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:24:21 + setp.lt.u32 %p35, %r75, 2; + .loc 1 25 37 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:25:37 + and.b32 %r34, %r1, 31; +$L__tmp1: + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r35}, %rd186; + cvt.u32.u64 %r36, %rd186; + shfl.sync.bfly.b32 %r37, %r36, 16, 31, -1; + shfl.sync.bfly.b32 %r38, %r35, 16, 31, -1; + cvt.u64.u32 %rd134, %r37; + cvt.u64.u32 %rd135, %r38; + shl.b64 %rd136, %rd135, 32; + or.b64 %rd137, %rd134, %rd136; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd138, %rd186, %rd137; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r39}, %rd138; + cvt.u32.u64 %r40, %rd138; + shfl.sync.bfly.b32 %r41, %r40, 8, 31, -1; + shfl.sync.bfly.b32 %r42, %r39, 8, 31, -1; + cvt.u64.u32 %rd139, %r41; + cvt.u64.u32 %rd140, %r42; + shl.b64 %rd141, %rd140, 32; + or.b64 %rd142, %rd139, %rd141; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd143, %rd138, %rd142; + .loc 2 
291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r43}, %rd143; + cvt.u32.u64 %r44, %rd143; + shfl.sync.bfly.b32 %r45, %r44, 4, 31, -1; + shfl.sync.bfly.b32 %r46, %r43, 4, 31, -1; + cvt.u64.u32 %rd144, %r45; + cvt.u64.u32 %rd145, %r46; + shl.b64 %rd146, %rd145, 32; + or.b64 %rd147, %rd144, %rd146; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd148, %rd143, %rd147; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r47}, %rd148; + cvt.u32.u64 %r48, %rd148; + shfl.sync.bfly.b32 %r49, %r48, 2, 31, -1; + shfl.sync.bfly.b32 %r50, %r47, 2, 31, -1; + cvt.u64.u32 %rd149, %r49; + cvt.u64.u32 %rd150, %r50; + shl.b64 %rd151, %rd150, 32; + or.b64 %rd152, %rd149, %rd151; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd153, %rd148, %rd152; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r51}, %rd153; + cvt.u32.u64 %r52, %rd153; + shfl.sync.bfly.b32 %r53, %r52, 1, 31, -1; + shfl.sync.bfly.b32 %r54, %r51, 1, 31, -1; + cvt.u64.u32 %rd154, %r53; + cvt.u64.u32 %rd155, %r54; + shl.b64 %rd156, %rd155, 32; + or.b64 %rd157, %rd154, %rd156; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd129, %rd153, %rd157; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + setp.eq.b32 %p31, %r34, 0; + shr.u32 %r55, %r1, 2; + and.b32 %r56, %r55, 120; + mov.b32 %r57, global_smem; + add.s32 %r30, %r57, %r56; + // begin inline asm + @%p31 st.shared.b64 [ %r30 + 0 ], %rd129; + // end inline asm + bar.sync 0; + setp.lt.u32 %p32, %r1, 16; + shl.b32 %r58, %r1, 3; + add.s32 %r31, %r57, %r58; + // begin inline asm + @%p32 ld.shared.b64 %rd130, [ %r31 + 
0 ]; + // end inline asm + mov.b64 {_, %r59}, %rd130; + cvt.u32.u64 %r60, %rd130; + shfl.sync.bfly.b32 %r61, %r60, 8, 31, -1; + shfl.sync.bfly.b32 %r62, %r59, 8, 31, -1; + cvt.u64.u32 %rd158, %r61; + cvt.u64.u32 %rd159, %r62; + shl.b64 %rd160, %rd159, 32; + or.b64 %rd161, %rd158, %rd160; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd162, %rd130, %rd161; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r63}, %rd162; + cvt.u32.u64 %r64, %rd162; + shfl.sync.bfly.b32 %r65, %r64, 4, 31, -1; + shfl.sync.bfly.b32 %r66, %r63, 4, 31, -1; + cvt.u64.u32 %rd163, %r65; + cvt.u64.u32 %rd164, %r66; + shl.b64 %rd165, %rd164, 32; + or.b64 %rd166, %rd163, %rd165; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd167, %rd162, %rd166; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r67}, %rd167; + cvt.u32.u64 %r68, %rd167; + shfl.sync.bfly.b32 %r69, %r68, 2, 31, -1; + shfl.sync.bfly.b32 %r70, %r67, 2, 31, -1; + cvt.u64.u32 %rd168, %r69; + cvt.u64.u32 %rd169, %r70; + shl.b64 %rd170, %rd169, 32; + or.b64 %rd171, %rd168, %rd170; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd172, %rd167, %rd171; + .loc 2 291 36 // standard.py:291:36 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + mov.b64 {_, %r71}, %rd172; + cvt.u32.u64 %r72, %rd172; + shfl.sync.bfly.b32 %r73, %r72, 1, 31, -1; + shfl.sync.bfly.b32 %r74, %r71, 1, 31, -1; + cvt.u64.u32 %rd173, %r73; + cvt.u64.u32 %rd174, %r74; + shl.b64 %rd175, %rd174, 32; + or.b64 %rd176, %rd173, %rd175; + .loc 2 261 15 // standard.py:261:15 @[ cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + add.s64 %rd131, %rd172, %rd176; + .loc 2 291 36 // standard.py:291:36 @[ 
cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:44:25 ] + setp.eq.b32 %p33, %r1, 0; + // begin inline asm + @%p33 st.shared.b64 [ %r31 + 0 ], %rd131; + // end inline asm + bar.sync 0; + ld.shared.b64 %rd132, [global_smem]; +$L__tmp2: + .loc 1 45 25 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:45:25 + shl.b64 %rd177, %rd1, 3; + add.s64 %rd133, %rd39, %rd177; + .loc 1 45 36 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:45:36 + setp.eq.b32 %p36, %r2, 0; + and.pred %p34, %p35, %p36; + // begin inline asm + @%p34 st.global.b64 [ %rd133 + 0 ], { %rd132 }; + // end inline asm + .loc 1 45 4 // cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py:45:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // 
DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 222 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xd7 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 101 +.b8 99 +.b8 110 +.b8 51 +.b8 51 +.b8 117 +.b8 101 +.b8 101 +.b8 111 +.b8 113 +.b8 112 +.b8 98 +.b8 117 +.b8 53 +.b8 115 +.b8 112 +.b8 50 +.b8 111 +.b8 55 +.b8 53 +.b8 108 +.b8 122 +.b8 108 +.b8 122 +.b8 108 +.b8 117 +.b8 106 +.b8 107 +.b8 120 +.b8 115 +.b8 122 +.b8 111 +.b8 103 +.b8 99 +.b8 117 +.b8 122 +.b8 101 +.b8 117 +.b8 111 +.b8 53 +.b8 102 +.b8 121 +.b8 51 +.b8 116 +.b8 117 +.b8 102 +.b8 113 +.b8 104 +.b8 113 +.b8 113 +.b8 118 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 101 +.b8 99 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x28 DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 101 +.b8 113 +.b8 95 +.b8 109 +.b8 117 +.b8 108 +.b8 95 +.b8 115 +.b8 113 +.b8 117 
+.b8 101 +.b8 101 +.b8 122 +.b8 101 +.b8 95 +.b8 115 +.b8 117 +.b8 109 +.b8 95 +.b8 50 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xb3:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xc8:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 44 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.source b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.source new file mode 100644 index 0000000000000000000000000000000000000000..702c894d82fad9669e484c9b93ec85901bd74221 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.source @@ -0,0 +1,282 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":18:0) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc61 = loc(unknown) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc68 = loc("in_ptr0"(#loc)) +#loc69 = loc("in_ptr1"(#loc)) +#loc70 = loc("in_ptr2"(#loc)) +#loc71 = loc("out_ptr0"(#loc)) +#loc72 = loc("ks0"(#loc)) +#loc73 = loc("xnumel"(#loc)) +#loc74 = loc("r0_numel"(#loc)) +#loc129 = loc("input"(#loc59)) +#loc130 = loc("a"(#loc64)) +#loc131 = loc("b"(#loc64)) +module { + tt.func public @triton_red_fused_eq_mul_squeeze_sum_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), 
%in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i64 loc("ks0"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 2 : i32 loc(#loc75) + %xoffset = tt.get_program_id x : i32 loc(#loc76) + %xoffset_1 = arith.constant 1 : i32 loc(#loc77) + %xoffset_2 = arith.constant 1 : i32 loc(#loc77) + %xoffset_3 = arith.muli %xoffset, %xoffset_2 : i32 loc(#loc77) + %xindex = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> loc(#loc78) + %xindex_4 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc79) + %xindex_5 = tt.splat %xoffset_3 : i32 -> tensor<1x1xi32> loc(#loc80) + %xindex_6 = arith.addi %xindex_5, %xindex_4 : tensor<1x1xi32> loc(#loc80) + %xmask = arith.constant dense<2> : tensor<1x1xi32> loc(#loc81) + %xmask_7 = arith.cmpi slt, %xindex_6, %xmask : tensor<1x1xi32> loc(#loc81) + %r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> loc(#loc82) + %r0_base_8 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32> -> tensor<1x2048xi32> loc(#loc83) + %_tmp7 = arith.constant 0 : i64 loc(#loc84) + %_tmp7_9 = arith.constant dense<0> : tensor<1x2048xi64> loc(#loc84) + %c0_i32 = arith.constant 0 : i32 loc(#loc11) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc11) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc11) + %1 = arith.bitcast %r0_numel : i32 to i32 loc(#loc11) + %2 = arith.bitcast %c2048_i32 : i32 to i32 loc(#loc11) + %3 = ub.poison : i32 loc(#loc11) + %_tmp7_10 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp7_12 = %_tmp7_9) -> (tensor<1x2048xi64>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x2048xi32> loc(#loc86) + %r0_index_13 = arith.addi %r0_index, %r0_base_8 : tensor<1x2048xi32> loc(#loc86) + %r0_mask = tt.splat 
%r0_numel : i32 -> tensor<1x2048xi32> loc(#loc87) + %r0_mask_14 = arith.cmpi slt, %r0_index_13, %r0_mask : tensor<1x2048xi32> loc(#loc87) + %tmp0 = arith.constant 4 : i32 loc(#loc88) + %tmp0_15 = arith.constant 4 : i64 loc(#loc88) + %tmp0_16 = arith.muli %tmp0_15, %ks0 : i64 loc(#loc88) + %tmp0_17 = arith.extsi %xindex_6 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc89) + %tmp0_18 = tt.splat %tmp0_16 : i64 -> tensor<1x1xi64> loc(#loc89) + %tmp0_19 = arith.muli %tmp0_18, %tmp0_17 : tensor<1x1xi64> loc(#loc89) + %tmp0_20 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc90) + %tmp0_21 = tt.broadcast %tmp0_19 : tensor<1x1xi64> -> tensor<1x2048xi64> loc(#loc90) + %tmp0_22 = arith.addi %tmp0_20, %tmp0_21 : tensor<1x2048xi64> loc(#loc90) + %tmp0_23 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc91) + %tmp0_24 = arith.divsi %tmp0_22, %tmp0_23 : tensor<1x2048xi64> loc(#loc91) + %tmp0_25 = arith.constant 8 : i32 loc(#loc92) + %tmp0_26 = arith.constant 8 : i64 loc(#loc92) + %tmp0_27 = arith.constant dense<8> : tensor<1x2048xi64> loc(#loc92) + %tmp0_28 = arith.remsi %tmp0_24, %tmp0_27 : tensor<1x2048xi64> loc(#loc92) + %tmp0_29 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc93) + %tmp0_30 = arith.muli %tmp0_29, %tmp0_28 : tensor<1x2048xi64> loc(#loc93) + %tmp0_31 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc94) + %tmp0_32 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc94) + %tmp0_33 = arith.remsi %tmp0_31, %tmp0_32 : tensor<1x2048xi64> loc(#loc94) + %tmp0_34 = arith.addi %tmp0_30, %tmp0_33 : tensor<1x2048xi64> loc(#loc95) + %tmp0_35 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc96) + %tmp0_36 = tt.addptr %tmp0_35, %tmp0_34 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc96) + %tmp0_37 = tt.broadcast %xmask_7 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc97) + %tmp0_38 = arith.andi %r0_mask_14, %tmp0_37 : tensor<1x2048xi1> loc(#loc97) + %tmp0_39 = arith.constant 0.000000e+00 : f32 
loc(#loc98) + %tmp0_40 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32> loc(#loc98) + %tmp0_41 = arith.fptosi %tmp0_40 : tensor<1x2048xf32> to tensor<1x2048xi64> loc(#loc98) + %tmp0_42 = tt.load %tmp0_36, %tmp0_38, %tmp0_41 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc98) + %tmp1 = arith.constant 4 : i32 loc(#loc99) + %tmp1_43 = arith.constant 4 : i64 loc(#loc99) + %tmp1_44 = arith.muli %tmp1_43, %ks0 : i64 loc(#loc99) + %tmp1_45 = arith.extsi %xindex_6 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc100) + %tmp1_46 = tt.splat %tmp1_44 : i64 -> tensor<1x1xi64> loc(#loc100) + %tmp1_47 = arith.muli %tmp1_46, %tmp1_45 : tensor<1x1xi64> loc(#loc100) + %tmp1_48 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc101) + %tmp1_49 = tt.broadcast %tmp1_47 : tensor<1x1xi64> -> tensor<1x2048xi64> loc(#loc101) + %tmp1_50 = arith.addi %tmp1_48, %tmp1_49 : tensor<1x2048xi64> loc(#loc101) + %tmp1_51 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc102) + %tmp1_52 = arith.divsi %tmp1_50, %tmp1_51 : tensor<1x2048xi64> loc(#loc102) + %tmp1_53 = arith.constant 8 : i32 loc(#loc103) + %tmp1_54 = arith.constant 8 : i64 loc(#loc103) + %tmp1_55 = arith.constant dense<8> : tensor<1x2048xi64> loc(#loc103) + %tmp1_56 = arith.remsi %tmp1_52, %tmp1_55 : tensor<1x2048xi64> loc(#loc103) + %tmp1_57 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc104) + %tmp1_58 = arith.muli %tmp1_57, %tmp1_56 : tensor<1x2048xi64> loc(#loc104) + %tmp1_59 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc105) + %tmp1_60 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc105) + %tmp1_61 = arith.remsi %tmp1_59, %tmp1_60 : tensor<1x2048xi64> loc(#loc105) + %tmp1_62 = arith.addi %tmp1_58, %tmp1_61 : tensor<1x2048xi64> loc(#loc106) + %tmp1_63 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc107) + %tmp1_64 = tt.addptr %tmp1_63, %tmp1_62 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc107) + %tmp1_65 = tt.broadcast 
%xmask_7 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc108) + %tmp1_66 = arith.andi %r0_mask_14, %tmp1_65 : tensor<1x2048xi1> loc(#loc108) + %tmp1_67 = arith.constant 0.000000e+00 : f32 loc(#loc109) + %tmp1_68 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32> loc(#loc109) + %tmp1_69 = arith.fptosi %tmp1_68 : tensor<1x2048xf32> to tensor<1x2048xi64> loc(#loc109) + %tmp1_70 = tt.load %tmp1_64, %tmp1_66, %tmp1_69 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc109) + %tmp4 = arith.constant 4 : i32 loc(#loc110) + %tmp4_71 = arith.constant 4 : i64 loc(#loc110) + %tmp4_72 = arith.muli %tmp4_71, %ks0 : i64 loc(#loc110) + %tmp4_73 = arith.extsi %xindex_6 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc111) + %tmp4_74 = tt.splat %tmp4_72 : i64 -> tensor<1x1xi64> loc(#loc111) + %tmp4_75 = arith.muli %tmp4_74, %tmp4_73 : tensor<1x1xi64> loc(#loc111) + %tmp4_76 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc112) + %tmp4_77 = tt.broadcast %tmp4_75 : tensor<1x1xi64> -> tensor<1x2048xi64> loc(#loc112) + %tmp4_78 = arith.addi %tmp4_76, %tmp4_77 : tensor<1x2048xi64> loc(#loc112) + %tmp4_79 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc113) + %tmp4_80 = arith.divsi %tmp4_78, %tmp4_79 : tensor<1x2048xi64> loc(#loc113) + %tmp4_81 = arith.constant 8 : i32 loc(#loc114) + %tmp4_82 = arith.constant 8 : i64 loc(#loc114) + %tmp4_83 = arith.constant dense<8> : tensor<1x2048xi64> loc(#loc114) + %tmp4_84 = arith.remsi %tmp4_80, %tmp4_83 : tensor<1x2048xi64> loc(#loc114) + %tmp4_85 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc115) + %tmp4_86 = arith.muli %tmp4_85, %tmp4_84 : tensor<1x2048xi64> loc(#loc115) + %tmp4_87 = arith.extsi %r0_index_13 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc116) + %tmp4_88 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc116) + %tmp4_89 = arith.remsi %tmp4_87, %tmp4_88 : tensor<1x2048xi64> loc(#loc116) + %tmp4_90 = arith.addi %tmp4_86, %tmp4_89 : tensor<1x2048xi64> loc(#loc117) + %tmp4_91 = 
tt.splat %in_ptr2 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc118) + %tmp4_92 = tt.addptr %tmp4_91, %tmp4_90 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc118) + %tmp4_93 = tt.broadcast %xmask_7 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc119) + %tmp4_94 = arith.andi %r0_mask_14, %tmp4_93 : tensor<1x2048xi1> loc(#loc119) + %tmp4_95 = arith.constant 0.000000e+00 : f32 loc(#loc120) + %tmp4_96 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32> loc(#loc120) + %tmp4_97 = arith.fptosi %tmp4_96 : tensor<1x2048xf32> to tensor<1x2048xi64> loc(#loc120) + %tmp4_98 = tt.load %tmp4_92, %tmp4_94, %tmp4_97 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc120) + %tmp2 = arith.cmpi eq, %tmp0_42, %tmp1_70 : tensor<1x2048xi64> loc(#loc121) + %tmp3 = arith.extui %tmp2 : tensor<1x2048xi1> to tensor<1x2048xi64> loc(#loc122) + %tmp5 = arith.muli %tmp3, %tmp4_98 : tensor<1x2048xi64> loc(#loc123) + %tmp8 = arith.addi %_tmp7_12, %tmp5 : tensor<1x2048xi64> loc(#loc124) + %_tmp7_99 = tt.broadcast %xmask_7 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc125) + %_tmp7_100 = arith.andi %r0_mask_14, %_tmp7_99 : tensor<1x2048xi1> loc(#loc125) + %_tmp7_101 = arith.select %_tmp7_100, %tmp8, %_tmp7_12 : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc126) + scf.yield %_tmp7_101 : tensor<1x2048xi64> loc(#loc53) + } loc(#loc85) + %tmp7 = tt.call @"triton.language.standard.sum__i64S1_2048S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp7_10) : (tensor<1x2048xi64>) -> tensor<1xi64> loc(#loc127) + %tmp7_11 = tt.expand_dims %tmp7 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc128) + %4 = tt.splat %out_ptr0 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc56) + %5 = tt.addptr %4, %xindex_6 : tensor<1x1x!tt.ptr>, tensor<1x1xi32> loc(#loc56) + tt.store %5, %tmp7_11, %xmask_7 : tensor<1x1x!tt.ptr> loc(#loc57) + tt.return loc(#loc58) + } loc(#loc) + tt.func private @"triton.language.standard.sum__i64S1_2048S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: 
tensor<1x2048xi64> loc("input"(#loc59))) -> tensor<1xi64> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i64 loc(unknown), %arg2: i64 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i64_i64__(%arg1, %arg2) : (i64, i64) -> i64 loc(#loc60) + tt.reduce.return %2 : i64 loc(#loc60) + }) : (tensor<1x2048xi64>) -> tensor<1xi64> loc(#loc60) + tt.return %0 : tensor<1xi64> loc(#loc62) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1xi64> loc(#loc63) + tt.return %1 : tensor<1xi64> loc(#loc63) + } loc(#loc59) + tt.func private @triton.language.standard._sum_combine__i64_i64__(%a: i64 loc("a"(#loc64)), %b: i64 loc("b"(#loc64))) -> i64 attributes {noinline = false} { + %0 = arith.addi %a, %b : i64 loc(#loc65) + tt.return %0 : i64 loc(#loc66) + ^bb1: // no predecessors + %1 = ub.poison : i64 loc(#loc67) + tt.return %1 : i64 loc(#loc67) + } loc(#loc64) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":22:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":23:36) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":23:44) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":23:23) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":24:21) +#loc8 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":25:27) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":25:37) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":28:43) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":29:40) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":30:31) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":31:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:51) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:55) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:49) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:62) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:69) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:40) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:84) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:77) +#loc22 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:34) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:102) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:92) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:51) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:55) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:49) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:62) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:69) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:40) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:84) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:77) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:34) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:102) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:92) +#loc36 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:51) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:55) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:49) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:62) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:69) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:40) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:84) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:77) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:34) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:102) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:92) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":38:23) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":39:23) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":40:22) +#loc50 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":42:23) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:35) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:48) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:8) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:25) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:28) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:25) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:36) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:4) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc65 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc67 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc75 = loc("xnumel"(#loc1)) +#loc76 = loc("xoffset"(#loc2)) +#loc77 = loc("xoffset"(#loc3)) +#loc78 = loc("xindex"(#loc4)) +#loc79 = loc("xindex"(#loc5)) +#loc80 = 
loc("xindex"(#loc6)) +#loc81 = loc("xmask"(#loc7)) +#loc82 = loc("r0_base"(#loc8)) +#loc83 = loc("r0_base"(#loc9)) +#loc84 = loc("_tmp7"(#loc10)) +#loc85 = loc("_tmp7"(#loc11)) +#loc86 = loc("r0_index"(#loc12)) +#loc87 = loc("r0_mask"(#loc13)) +#loc88 = loc("tmp0"(#loc14)) +#loc89 = loc("tmp0"(#loc15)) +#loc90 = loc("tmp0"(#loc16)) +#loc91 = loc("tmp0"(#loc17)) +#loc92 = loc("tmp0"(#loc18)) +#loc93 = loc("tmp0"(#loc19)) +#loc94 = loc("tmp0"(#loc20)) +#loc95 = loc("tmp0"(#loc21)) +#loc96 = loc("tmp0"(#loc22)) +#loc97 = loc("tmp0"(#loc23)) +#loc98 = loc("tmp0"(#loc24)) +#loc99 = loc("tmp1"(#loc25)) +#loc100 = loc("tmp1"(#loc26)) +#loc101 = loc("tmp1"(#loc27)) +#loc102 = loc("tmp1"(#loc28)) +#loc103 = loc("tmp1"(#loc29)) +#loc104 = loc("tmp1"(#loc30)) +#loc105 = loc("tmp1"(#loc31)) +#loc106 = loc("tmp1"(#loc32)) +#loc107 = loc("tmp1"(#loc33)) +#loc108 = loc("tmp1"(#loc34)) +#loc109 = loc("tmp1"(#loc35)) +#loc110 = loc("tmp4"(#loc36)) +#loc111 = loc("tmp4"(#loc37)) +#loc112 = loc("tmp4"(#loc38)) +#loc113 = loc("tmp4"(#loc39)) +#loc114 = loc("tmp4"(#loc40)) +#loc115 = loc("tmp4"(#loc41)) +#loc116 = loc("tmp4"(#loc42)) +#loc117 = loc("tmp4"(#loc43)) +#loc118 = loc("tmp4"(#loc44)) +#loc119 = loc("tmp4"(#loc45)) +#loc120 = loc("tmp4"(#loc46)) +#loc121 = loc("tmp2"(#loc47)) +#loc122 = loc("tmp3"(#loc48)) +#loc123 = loc("tmp5"(#loc49)) +#loc124 = loc("tmp8"(#loc50)) +#loc125 = loc("_tmp7"(#loc51)) +#loc126 = loc("_tmp7"(#loc52)) +#loc127 = loc("tmp7"(#loc54)) +#loc128 = loc("tmp7"(#loc55)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..94efbb30b018a9948962f14ea36e75ca54be6a78 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttgir @@ -0,0 +1,137 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 16], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":18:0) +#loc1 = loc(unknown) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:25) +#loc36 = loc("in_ptr0"(#loc)) +#loc37 = loc("in_ptr1"(#loc)) +#loc38 = loc("in_ptr2"(#loc)) +#loc39 = loc("out_ptr0"(#loc)) +#loc40 = loc("ks0"(#loc)) +#loc41 = loc("xnumel"(#loc)) +#loc42 = loc("r0_numel"(#loc)) +#loc69 = loc("tmp7"(#loc30)) +#loc74 = loc(callsite(#loc1 at #loc69)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_eq_mul_squeeze_sum_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i64 loc("ks0"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<8> : tensor<1x2048xi64, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c4_i64 = arith.constant 4 : i64 loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<1x2048xi64, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc43) + %xmask = arith.cmpi slt, %xoffset, %c2_i32 : i32 loc(#loc44) + %r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc45) + %r0_base_1 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x2048xi32, #blocked> loc(#loc45) + %r0_mask = tt.splat %r0_numel : i32 -> tensor<1x2048xi32, #blocked> loc(#loc46) + %tmp0 = arith.muli %ks0, %c4_i64 : i64 loc(#loc47) + %tmp0_2 = arith.extsi %xoffset : i32 to i64 loc(#loc48) + %tmp0_3 = arith.muli %tmp0, %tmp0_2 : i64 loc(#loc48) + %tmp0_4 = tt.splat %tmp0_3 : i64 -> tensor<1x2048xi64, #blocked> loc(#loc71) + %tmp0_5 = tt.splat %ks0 : i64 -> tensor<1x2048xi64, #blocked> loc(#loc50) + %tmp0_6 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr, #blocked> loc(#loc51) + %tmp0_7 = tt.splat %xmask : i1 -> tensor<1x2048xi1, #blocked> loc(#loc72) + %tmp1 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x2048x!tt.ptr, #blocked> loc(#loc53) + %tmp4 = tt.splat %in_ptr2 : !tt.ptr -> tensor<1x2048x!tt.ptr, #blocked> loc(#loc54) + %_tmp7 = scf.for %_tmp7_9 = %c0_i32 to %r0_numel step %c2048_i32 iter_args(%arg8 = %cst_0) -> (tensor<1x2048xi64, #blocked>) : i32 { + %r0_index = tt.splat %_tmp7_9 : i32 -> tensor<1x2048xi32, #blocked> loc(#loc56) + %r0_index_10 = arith.addi %r0_index, %r0_base_1 : tensor<1x2048xi32, #blocked> loc(#loc56) + %r0_mask_11 = arith.cmpi slt, %r0_index_10, %r0_mask : tensor<1x2048xi32, #blocked> loc(#loc46) + %tmp0_12 = arith.extsi %r0_index_10 : tensor<1x2048xi32, #blocked> to tensor<1x2048xi64, #blocked> loc(#loc49) + %tmp0_13 = arith.addi %tmp0_12, %tmp0_4 : tensor<1x2048xi64, #blocked> loc(#loc49) + %tmp0_14 = arith.divsi %tmp0_13, %tmp0_5 : tensor<1x2048xi64, #blocked> loc(#loc50) + %tmp0_15 = arith.remsi %tmp0_14, %cst : tensor<1x2048xi64, #blocked> loc(#loc57) + %tmp0_16 = arith.muli %tmp0_5, %tmp0_15 : tensor<1x2048xi64, #blocked> loc(#loc58) + %tmp0_17 = arith.remsi %tmp0_12, %tmp0_5 : tensor<1x2048xi64, #blocked> loc(#loc59) + %tmp0_18 = arith.addi %tmp0_16, %tmp0_17 : tensor<1x2048xi64, #blocked> loc(#loc60) + %tmp0_19 = 
tt.addptr %tmp0_6, %tmp0_18 : tensor<1x2048x!tt.ptr, #blocked>, tensor<1x2048xi64, #blocked> loc(#loc51) + %tmp0_20 = arith.andi %r0_mask_11, %tmp0_7 : tensor<1x2048xi1, #blocked> loc(#loc52) + %tmp0_21 = tt.load %tmp0_19, %tmp0_20, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr, #blocked> loc(#loc61) + %tmp1_22 = tt.addptr %tmp1, %tmp0_18 : tensor<1x2048x!tt.ptr, #blocked>, tensor<1x2048xi64, #blocked> loc(#loc53) + %tmp1_23 = tt.load %tmp1_22, %tmp0_20, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr, #blocked> loc(#loc62) + %tmp4_24 = tt.addptr %tmp4, %tmp0_18 : tensor<1x2048x!tt.ptr, #blocked>, tensor<1x2048xi64, #blocked> loc(#loc54) + %tmp4_25 = tt.load %tmp4_24, %tmp0_20, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr, #blocked> loc(#loc63) + %tmp2 = arith.cmpi eq, %tmp0_21, %tmp1_23 : tensor<1x2048xi64, #blocked> loc(#loc64) + %tmp3 = arith.extui %tmp2 : tensor<1x2048xi1, #blocked> to tensor<1x2048xi64, #blocked> loc(#loc65) + %tmp5 = arith.muli %tmp3, %tmp4_25 : tensor<1x2048xi64, #blocked> loc(#loc66) + %tmp8 = arith.addi %arg8, %tmp5 : tensor<1x2048xi64, #blocked> loc(#loc67) + %_tmp7_26 = arith.select %tmp0_20, %tmp8, %arg8 : tensor<1x2048xi1, #blocked>, tensor<1x2048xi64, #blocked> loc(#loc68) + scf.yield %_tmp7_26 : tensor<1x2048xi64, #blocked> loc(#loc28) + } loc(#loc55) + %tmp7 = "tt.reduce"(%_tmp7) <{axis = 1 : i32}> ({ + ^bb0(%tmp7_9: i64 loc(callsite(#loc1 at #loc69)), %tmp7_10: i64 loc(callsite(#loc1 at #loc69))): + %tmp7_11 = arith.addi %tmp7_9, %tmp7_10 : i64 loc(#loc75) + tt.reduce.return %tmp7_11 : i64 loc(#loc73) + }) : (tensor<1x2048xi64, #blocked>) -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc73) + %tmp7_8 = tt.expand_dims %tmp7 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi64, #blocked> loc(#loc70) + %0 = tt.addptr %out_ptr0, %xoffset : !tt.ptr, i32 loc(#loc33) + %1 = tt.splat %0 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc34) + %2 
= tt.splat %xmask : i1 -> tensor<1x1xi1, #blocked> loc(#loc34) + tt.store %1, %tmp7_8, %2 : tensor<1x1x!tt.ptr, #blocked> loc(#loc34) + tt.return loc(#loc35) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":24:21) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":25:37) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":31:29) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:51) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:55) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:49) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:62) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:34) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:102) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:34) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:34) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":29:40) +#loc15 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":30:31) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:69) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:40) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:84) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:77) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:92) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:92) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:92) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":38:23) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":39:23) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":40:22) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":42:23) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:48) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:8) +#loc29 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:28) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:25) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:36) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:4) +#loc43 = loc("xoffset"(#loc2)) +#loc44 = loc("xmask"(#loc3)) +#loc45 = loc("r0_base"(#loc4)) +#loc46 = loc("r0_mask"(#loc5)) +#loc47 = loc("tmp0"(#loc6)) +#loc48 = loc("tmp0"(#loc7)) +#loc49 = loc("tmp0"(#loc8)) +#loc50 = loc("tmp0"(#loc9)) +#loc51 = loc("tmp0"(#loc10)) +#loc52 = loc("tmp0"(#loc11)) +#loc53 = loc("tmp1"(#loc12)) +#loc54 = loc("tmp4"(#loc13)) +#loc55 = loc("_tmp7"(#loc14)) +#loc56 = loc("r0_index"(#loc15)) +#loc57 = loc("tmp0"(#loc16)) +#loc58 = loc("tmp0"(#loc17)) +#loc59 = loc("tmp0"(#loc18)) +#loc60 = loc("tmp0"(#loc19)) +#loc61 = loc("tmp0"(#loc20)) +#loc62 = loc("tmp1"(#loc21)) +#loc63 = loc("tmp4"(#loc22)) +#loc64 = loc("tmp2"(#loc23)) +#loc65 = loc("tmp3"(#loc24)) +#loc66 = loc("tmp5"(#loc25)) +#loc67 = loc("tmp8"(#loc26)) +#loc68 = loc("_tmp7"(#loc27)) +#loc70 = loc("tmp7"(#loc32)) +#loc71 = loc(fused[#loc49, #loc48]) +#loc72 = loc(fused[#loc52, #loc44]) +#loc73 = loc(callsite(#loc29 at #loc69)) +#loc75 = loc(callsite(#loc31 at #loc73)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttir 
b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttir new file mode 100644 index 0000000000000000000000000000000000000000..3b2bea15688b103b179d1c6525a758d5a4eef5bd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/K6AHUKXY3DG2PXK4IJUBF54N3VWT4GZ62CFMGCGQUPTTKOJPUFHA/triton_red_fused_eq_mul_squeeze_sum_2.ttir @@ -0,0 +1,138 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":18:0) +#loc3 = loc(unknown) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:25) +#loc37 = loc("in_ptr0"(#loc)) +#loc38 = loc("in_ptr1"(#loc)) +#loc39 = loc("in_ptr2"(#loc)) +#loc40 = loc("out_ptr0"(#loc)) +#loc41 = loc("ks0"(#loc)) +#loc42 = loc("xnumel"(#loc)) +#loc43 = loc("r0_numel"(#loc)) +#loc71 = loc("tmp7"(#loc31)) +#loc76 = loc(callsite(#loc3 at #loc71)) +module { + tt.func public @triton_red_fused_eq_mul_squeeze_sum_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i64 loc("ks0"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 loc("r0_numel"(#loc))) attributes {noinline = false} { + %xmask = arith.constant 2 : i32 loc(#loc44) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc2) + %c0_i32 = arith.constant 0 : i32 loc(#loc2) + %cst = arith.constant dense<8> : tensor<1x2048xi64> loc(#loc3) + %c4_i64 = arith.constant 4 : i64 loc(#loc3) + %cst_0 = arith.constant dense<0> : tensor<1x2048xi64> loc(#loc3) + %xoffset = tt.get_program_id x : i32 loc(#loc45) + %xmask_1 = arith.cmpi slt, %xoffset, %xmask : i32 loc(#loc44) + %xmask_2 = tt.splat %xmask_1 : i1 -> tensor<1x1xi1> loc(#loc44) + 
%r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> loc(#loc46) + %r0_base_3 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32> -> tensor<1x2048xi32> loc(#loc47) + %_tmp7 = scf.for %r0_offset = %c0_i32 to %r0_numel step %c2048_i32 iter_args(%_tmp7_5 = %cst_0) -> (tensor<1x2048xi64>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x2048xi32> loc(#loc49) + %r0_index_6 = arith.addi %r0_index, %r0_base_3 : tensor<1x2048xi32> loc(#loc49) + %r0_mask = tt.splat %r0_numel : i32 -> tensor<1x2048xi32> loc(#loc50) + %r0_mask_7 = arith.cmpi slt, %r0_index_6, %r0_mask : tensor<1x2048xi32> loc(#loc50) + %tmp0 = arith.muli %ks0, %c4_i64 : i64 loc(#loc51) + %tmp0_8 = arith.extsi %xoffset : i32 to i64 loc(#loc52) + %tmp0_9 = arith.muli %tmp0, %tmp0_8 : i64 loc(#loc52) + %tmp0_10 = arith.extsi %r0_index_6 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc53) + %tmp0_11 = tt.splat %tmp0_9 : i64 -> tensor<1x2048xi64> loc(#loc73) + %tmp0_12 = arith.addi %tmp0_10, %tmp0_11 : tensor<1x2048xi64> loc(#loc53) + %tmp0_13 = tt.splat %ks0 : i64 -> tensor<1x2048xi64> loc(#loc54) + %tmp0_14 = arith.divsi %tmp0_12, %tmp0_13 : tensor<1x2048xi64> loc(#loc54) + %tmp0_15 = arith.remsi %tmp0_14, %cst : tensor<1x2048xi64> loc(#loc55) + %tmp0_16 = arith.muli %tmp0_13, %tmp0_15 : tensor<1x2048xi64> loc(#loc56) + %tmp0_17 = arith.remsi %tmp0_10, %tmp0_13 : tensor<1x2048xi64> loc(#loc57) + %tmp0_18 = arith.addi %tmp0_16, %tmp0_17 : tensor<1x2048xi64> loc(#loc58) + %tmp0_19 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc59) + %tmp0_20 = tt.addptr %tmp0_19, %tmp0_18 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc59) + %tmp0_21 = tt.splat %xmask_1 : i1 -> tensor<1x2048xi1> loc(#loc74) + %tmp0_22 = arith.andi %r0_mask_7, %tmp0_21 : tensor<1x2048xi1> loc(#loc60) + %tmp0_23 = tt.load %tmp0_20, %tmp0_22, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc61) + %tmp1 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x2048x!tt.ptr> 
loc(#loc62) + %tmp1_24 = tt.addptr %tmp1, %tmp0_18 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc62) + %tmp1_25 = tt.load %tmp1_24, %tmp0_22, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc63) + %tmp4 = tt.splat %in_ptr2 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc64) + %tmp4_26 = tt.addptr %tmp4, %tmp0_18 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc64) + %tmp4_27 = tt.load %tmp4_26, %tmp0_22, %cst_0 evictionPolicy = evict_last : tensor<1x2048x!tt.ptr> loc(#loc65) + %tmp2 = arith.cmpi eq, %tmp0_23, %tmp1_25 : tensor<1x2048xi64> loc(#loc66) + %tmp3 = arith.extui %tmp2 : tensor<1x2048xi1> to tensor<1x2048xi64> loc(#loc67) + %tmp5 = arith.muli %tmp3, %tmp4_27 : tensor<1x2048xi64> loc(#loc68) + %tmp8 = arith.addi %_tmp7_5, %tmp5 : tensor<1x2048xi64> loc(#loc69) + %_tmp7_28 = arith.select %tmp0_22, %tmp8, %_tmp7_5 : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc70) + scf.yield %_tmp7_28 : tensor<1x2048xi64> loc(#loc29) + } loc(#loc48) + %tmp7 = "tt.reduce"(%_tmp7) <{axis = 1 : i32}> ({ + ^bb0(%tmp7_5: i64 loc(callsite(#loc3 at #loc71)), %tmp7_6: i64 loc(callsite(#loc3 at #loc71))): + %tmp7_7 = arith.addi %tmp7_5, %tmp7_6 : i64 loc(#loc77) + tt.reduce.return %tmp7_7 : i64 loc(#loc75) + }) : (tensor<1x2048xi64>) -> tensor<1xi64> loc(#loc75) + %tmp7_4 = tt.expand_dims %tmp7 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc72) + %0 = tt.addptr %out_ptr0, %xoffset : !tt.ptr, i32 loc(#loc34) + %1 = tt.splat %0 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc34) + tt.store %1, %tmp7_4, %xmask_2 : tensor<1x1x!tt.ptr> loc(#loc35) + tt.return loc(#loc36) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":24:21) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":29:40) +#loc4 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":22:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":25:27) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":25:37) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":30:31) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":31:29) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:51) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:55) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:49) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:62) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:69) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:40) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:84) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:77) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:34) +#loc18 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:102) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":35:92) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:34) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":36:92) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:34) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":37:92) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":38:23) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":39:23) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":40:22) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":42:23) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:48) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":43:8) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":44:28) +#loc34 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:25) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:36) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ec/cecn33ueeoqpbu5sp2o75lzlzlujkxszogcuzeuo5fy3tufqhqqv.py":45:4) +#loc44 = loc("xmask"(#loc1)) +#loc45 = loc("xoffset"(#loc4)) +#loc46 = loc("r0_base"(#loc5)) +#loc47 = loc("r0_base"(#loc6)) +#loc48 = loc("_tmp7"(#loc2)) +#loc49 = loc("r0_index"(#loc7)) +#loc50 = loc("r0_mask"(#loc8)) +#loc51 = loc("tmp0"(#loc9)) +#loc52 = loc("tmp0"(#loc10)) +#loc53 = loc("tmp0"(#loc11)) +#loc54 = loc("tmp0"(#loc12)) +#loc55 = loc("tmp0"(#loc13)) +#loc56 = loc("tmp0"(#loc14)) +#loc57 = loc("tmp0"(#loc15)) +#loc58 = loc("tmp0"(#loc16)) +#loc59 = loc("tmp0"(#loc17)) +#loc60 = loc("tmp0"(#loc18)) +#loc61 = loc("tmp0"(#loc19)) +#loc62 = loc("tmp1"(#loc20)) +#loc63 = loc("tmp1"(#loc21)) +#loc64 = loc("tmp4"(#loc22)) +#loc65 = loc("tmp4"(#loc23)) +#loc66 = loc("tmp2"(#loc24)) +#loc67 = loc("tmp3"(#loc25)) +#loc68 = loc("tmp5"(#loc26)) +#loc69 = loc("tmp8"(#loc27)) +#loc70 = loc("_tmp7"(#loc28)) +#loc72 = loc("tmp7"(#loc33)) +#loc73 = loc(fused[#loc53, #loc52]) +#loc74 = loc(fused[#loc60, #loc44]) +#loc75 = loc(callsite(#loc30 at #loc71)) +#loc77 = loc(callsite(#loc32 at #loc75)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json new file mode 100644 index 0000000000000000000000000000000000000000..ce720b2579cdecc6698ce411056b513f2c89d5b9 --- 
/dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/__grp__triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json @@ -0,0 +1 @@ +{"child_paths": {"triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin", "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin new file mode 100644 index 0000000000000000000000000000000000000000..9e64649230060cfe0b7b5409515c15ef4ece80c0 Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.cubin differ diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json new file mode 100644 index 0000000000000000000000000000000000000000..7848edd58fa6cb351a065f121ef7f66022cbb319 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.json @@ -0,0 +1 @@ +{"hash": "696e31cf28240068f521b5857aff70ca1774b115aa6afc9f623811458c0d1f69", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 2, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 0, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": 
"triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir new file mode 100644 index 0000000000000000000000000000000000000000..c68ad8377215810c738a408f21c275f6b5fc4125 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.llir @@ -0,0 +1,840 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@assertFunc_1 = internal constant [8 x i8] c"unknown\00" +@assertFile_1 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py\00" +@assertMessage_1 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp49 < 17\00" +@assertFunc_0 = internal constant [8 x i8] c"unknown\00" +@assertFile_0 = internal constant [114 x i8] c"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py\00" +@assertMessage_0 = internal constant [37 x i8] c"index out of bounds: 0 <= tmp40 < 17\00" + +; Function Attrs: noreturn +declare !dbg !5 void @__assertfail(ptr, ptr, i32, ptr, i64) local_unnamed_addr #0 + +define ptx_kernel void 
@triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, i32 %7, i32 %8, ptr addrspace(1) readnone captures(none) %9, ptr addrspace(1) readnone captures(none) %10) local_unnamed_addr #1 !dbg !9 { + %12 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !10 + %13 = icmp samesign ult i32 %12, 32, !dbg !11 + %14 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12 + %15 = and i32 %14, 15, !dbg !12 + %16 = shl i32 %12, 4, !dbg !13 + %17 = or disjoint i32 %15, %16, !dbg !14 + %18 = sext i32 %17 to i64, !dbg !15 + %19 = getelementptr i64, ptr addrspace(1) %0, i64 %18, !dbg !15 + %20 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %19, i1 %13) #5, !dbg !16 + %21 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %19, i1 %13) #5, !dbg !16 + %22 = add i64 %20, -1, !dbg !17 + %23 = icmp ult i64 %22, 16383, !dbg !17 + %24 = add i64 %21, -1, !dbg !17 + %25 = icmp ult i64 %24, 16383, !dbg !17 + %26 = zext i1 %23 to i32, !dbg !18 + %27 = lshr i32 %14, 1, !dbg !19 + %.lobit = and i32 %27, 1, !dbg !19 + %28 = and i32 %14, 1, !dbg !19 + %29 = lshr i32 %14, 2, !dbg !19 + %.lobit1 = and i32 %29, 1, !dbg !19 + %30 = lshr i32 %14, 3, !dbg !19 + %.lobit2 = and i32 %30, 1, !dbg !19 + %31 = xor i32 %28, 1, !dbg !23 + %32 = xor i32 %.lobit, 1, !dbg !23 + %33 = xor i32 %.lobit1, 1, !dbg !23 + %34 = xor i32 %.lobit2, 1, !dbg !23 + %35 = select i1 %23, i32 %31, i32 0, !dbg !24 + %36 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %35, i32 1, i32 31), !dbg !25 + %37 = add i32 %35, %36, !dbg !28 + %38 = select i1 %23, i32 %28, i32 0, !dbg !29 + %39 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %38, i32 1, i32 31), 
!dbg !25 + %40 = add i32 %38, %39, !dbg !28 + %41 = mul nuw nsw i32 %31, %15, !dbg !30 + %42 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %41, i32 1, i32 31), !dbg !25 + %43 = add i32 %42, %41, !dbg !28 + %44 = mul nuw nsw i32 %15, %28, !dbg !31 + %45 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %44, i32 1, i32 31), !dbg !25 + %46 = add i32 %45, %44, !dbg !28 + %47 = icmp slt i32 %37, %40, !dbg !32 + %48 = icmp eq i32 %37, %40, !dbg !33 + %49 = icmp sgt i32 %43, %46, !dbg !34 + %50 = and i1 %48, %49, !dbg !35 + %51 = or i1 %47, %50, !dbg !36 + %52 = trunc i32 %27 to i1, !dbg !37 + %53 = xor i1 %51, %52, !dbg !37 + %54 = xor i32 %37, %40, !dbg !38 + %55 = select i1 %53, i32 %54, i32 0, !dbg !39 + %56 = xor i32 %55, %26, !dbg !40 + %57 = xor i32 %46, %43, !dbg !41 + %58 = select i1 %53, i32 %57, i32 0, !dbg !42 + %59 = xor i32 %58, %15, !dbg !43 + %60 = mul nuw nsw i32 %56, %32, !dbg !24 + %61 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %60, i32 2, i32 31), !dbg !25 + %62 = add i32 %60, %61, !dbg !28 + %63 = mul nuw nsw i32 %56, %.lobit, !dbg !29 + %64 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %63, i32 2, i32 31), !dbg !25 + %65 = add i32 %63, %64, !dbg !28 + %66 = mul nuw nsw i32 %59, %32, !dbg !30 + %67 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %66, i32 2, i32 31), !dbg !25 + %68 = add i32 %66, %67, !dbg !28 + %69 = mul nuw nsw i32 %59, %.lobit, !dbg !31 + %70 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %69, i32 2, i32 31), !dbg !25 + %71 = add i32 %69, %70, !dbg !28 + %72 = trunc i32 %29 to i1, !dbg !37 + %73 = icmp sge i32 %62, %65, !dbg !37 + %74 = icmp ne i32 %62, %65, !dbg !37 + %75 = icmp sle i32 %68, %71, !dbg !37 + %76 = or i1 %74, %75, !dbg !37 + %77 = and i1 %73, %76, !dbg !37 + %.not = xor i1 %77, %72, !dbg !37 + %78 = xor i32 %62, %65, !dbg !38 + %79 = select i1 %.not, i32 0, i32 %78, !dbg !39 + %80 = xor i32 %79, %56, !dbg !40 + %81 = xor i32 %68, %71, !dbg !41 + %82 
= select i1 %.not, i32 0, i32 %81, !dbg !42 + %83 = xor i32 %82, %59, !dbg !43 + %84 = mul nuw nsw i32 %80, %31, !dbg !24 + %85 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %84, i32 1, i32 31), !dbg !25 + %86 = add i32 %84, %85, !dbg !28 + %87 = mul nuw nsw i32 %80, %28, !dbg !29 + %88 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %87, i32 1, i32 31), !dbg !25 + %89 = add i32 %87, %88, !dbg !28 + %90 = mul nuw nsw i32 %83, %31, !dbg !30 + %91 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %90, i32 1, i32 31), !dbg !25 + %92 = add i32 %90, %91, !dbg !28 + %93 = mul nuw nsw i32 %83, %28, !dbg !31 + %94 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %93, i32 1, i32 31), !dbg !25 + %95 = add i32 %93, %94, !dbg !28 + %96 = icmp sge i32 %86, %89, !dbg !37 + %97 = icmp ne i32 %86, %89, !dbg !37 + %98 = icmp sle i32 %92, %95, !dbg !37 + %99 = or i1 %97, %98, !dbg !37 + %100 = and i1 %96, %99, !dbg !37 + %.not3 = xor i1 %100, %72, !dbg !37 + %101 = xor i32 %86, %89, !dbg !38 + %102 = select i1 %.not3, i32 0, i32 %101, !dbg !39 + %103 = xor i32 %102, %80, !dbg !40 + %104 = xor i32 %92, %95, !dbg !41 + %105 = select i1 %.not3, i32 0, i32 %104, !dbg !42 + %106 = xor i32 %105, %83, !dbg !43 + %107 = mul nuw nsw i32 %103, %33, !dbg !24 + %108 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %107, i32 4, i32 31), !dbg !25 + %109 = add i32 %107, %108, !dbg !28 + %110 = mul nuw nsw i32 %103, %.lobit1, !dbg !29 + %111 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %110, i32 4, i32 31), !dbg !25 + %112 = add i32 %110, %111, !dbg !28 + %113 = mul nuw nsw i32 %106, %33, !dbg !30 + %114 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %113, i32 4, i32 31), !dbg !25 + %115 = add i32 %113, %114, !dbg !28 + %116 = mul nuw nsw i32 %106, %.lobit1, !dbg !31 + %117 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %116, i32 4, i32 31), !dbg !25 + %118 = add i32 %116, %117, !dbg !28 + %119 = trunc i32 %30 to i1, 
!dbg !37 + %120 = icmp sge i32 %109, %112, !dbg !37 + %121 = icmp ne i32 %109, %112, !dbg !37 + %122 = icmp sle i32 %115, %118, !dbg !37 + %123 = or i1 %121, %122, !dbg !37 + %124 = and i1 %120, %123, !dbg !37 + %.not4 = xor i1 %124, %119, !dbg !37 + %125 = xor i32 %109, %112, !dbg !38 + %126 = select i1 %.not4, i32 0, i32 %125, !dbg !39 + %127 = xor i32 %126, %103, !dbg !40 + %128 = xor i32 %115, %118, !dbg !41 + %129 = select i1 %.not4, i32 0, i32 %128, !dbg !42 + %130 = xor i32 %129, %106, !dbg !43 + %131 = mul nuw nsw i32 %127, %32, !dbg !24 + %132 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %131, i32 2, i32 31), !dbg !25 + %133 = add i32 %131, %132, !dbg !28 + %134 = mul nuw nsw i32 %127, %.lobit, !dbg !29 + %135 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %134, i32 2, i32 31), !dbg !25 + %136 = add i32 %134, %135, !dbg !28 + %137 = mul nuw nsw i32 %130, %32, !dbg !30 + %138 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %137, i32 2, i32 31), !dbg !25 + %139 = add i32 %137, %138, !dbg !28 + %140 = mul nuw nsw i32 %130, %.lobit, !dbg !31 + %141 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %140, i32 2, i32 31), !dbg !25 + %142 = add i32 %140, %141, !dbg !28 + %143 = icmp sge i32 %133, %136, !dbg !37 + %144 = icmp ne i32 %133, %136, !dbg !37 + %145 = icmp sle i32 %139, %142, !dbg !37 + %146 = or i1 %144, %145, !dbg !37 + %147 = and i1 %143, %146, !dbg !37 + %.not5 = xor i1 %147, %119, !dbg !37 + %148 = xor i32 %133, %136, !dbg !38 + %149 = select i1 %.not5, i32 0, i32 %148, !dbg !39 + %150 = xor i32 %149, %127, !dbg !40 + %151 = xor i32 %139, %142, !dbg !41 + %152 = select i1 %.not5, i32 0, i32 %151, !dbg !42 + %153 = xor i32 %152, %130, !dbg !43 + %154 = mul nuw nsw i32 %150, %31, !dbg !24 + %155 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %154, i32 1, i32 31), !dbg !25 + %156 = add i32 %154, %155, !dbg !28 + %157 = mul nuw nsw i32 %150, %28, !dbg !29 + %158 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %157, i32 1, i32 31), !dbg !25 + %159 = add i32 %157, %158, !dbg !28 + %160 = mul nuw nsw i32 %153, %31, !dbg !30 + %161 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %160, i32 1, i32 31), !dbg !25 + %162 = add i32 %160, %161, !dbg !28 + %163 = mul nuw nsw i32 %153, %28, !dbg !31 + %164 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %163, i32 1, i32 31), !dbg !25 + %165 = add i32 %163, %164, !dbg !28 + %166 = icmp sge i32 %156, %159, !dbg !37 + %167 = icmp ne i32 %156, %159, !dbg !37 + %168 = icmp sle i32 %162, %165, !dbg !37 + %169 = or i1 %167, %168, !dbg !37 + %170 = and i1 %166, %169, !dbg !37 + %.not6 = xor i1 %170, %119, !dbg !37 + %171 = xor i32 %156, %159, !dbg !38 + %172 = select i1 %.not6, i32 0, i32 %171, !dbg !39 + %173 = xor i32 %172, %150, !dbg !40 + %174 = xor i32 %162, %165, !dbg !41 + %175 = select i1 %.not6, i32 0, i32 %174, !dbg !42 + %176 = xor i32 %175, %153, !dbg !43 + %177 = mul nuw nsw i32 %173, %34, !dbg !24 + %178 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %177, i32 8, i32 31), !dbg !25 + %179 = add i32 %177, %178, !dbg !28 + %180 = mul nuw nsw i32 %173, %.lobit2, !dbg !29 + %181 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %180, i32 8, i32 31), !dbg !25 + %182 = add i32 %180, %181, !dbg !28 + %183 = mul nuw nsw i32 %176, %34, !dbg !30 + %184 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %183, i32 8, i32 31), !dbg !25 + %185 = add i32 %183, %184, !dbg !28 + %186 = mul nuw nsw i32 %176, %.lobit2, !dbg !31 + %187 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %186, i32 8, i32 31), !dbg !25 + %188 = add i32 %186, %187, !dbg !28 + %189 = icmp slt i32 %179, %182, !dbg !32 + %190 = icmp eq i32 %179, %182, !dbg !33 + %191 = icmp sgt i32 %185, %188, !dbg !34 + %192 = and i1 %190, %191, !dbg !35 + %193 = or i1 %189, %192, !dbg !36 + %194 = xor i32 %179, %182, !dbg !38 + %195 = select i1 %193, i32 %194, i32 0, !dbg !39 + %196 = xor i32 
%195, %173, !dbg !40 + %197 = xor i32 %185, %188, !dbg !41 + %198 = select i1 %193, i32 %197, i32 0, !dbg !42 + %199 = xor i32 %198, %176, !dbg !43 + %200 = mul nuw nsw i32 %196, %33, !dbg !24 + %201 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %200, i32 4, i32 31), !dbg !25 + %202 = add i32 %200, %201, !dbg !28 + %203 = mul nuw nsw i32 %196, %.lobit1, !dbg !29 + %204 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %203, i32 4, i32 31), !dbg !25 + %205 = add i32 %203, %204, !dbg !28 + %206 = mul nuw nsw i32 %199, %33, !dbg !30 + %207 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %206, i32 4, i32 31), !dbg !25 + %208 = add i32 %206, %207, !dbg !28 + %209 = mul nuw nsw i32 %199, %.lobit1, !dbg !31 + %210 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %209, i32 4, i32 31), !dbg !25 + %211 = add i32 %209, %210, !dbg !28 + %212 = icmp slt i32 %202, %205, !dbg !32 + %213 = icmp eq i32 %202, %205, !dbg !33 + %214 = icmp sgt i32 %208, %211, !dbg !34 + %215 = and i1 %213, %214, !dbg !35 + %216 = or i1 %212, %215, !dbg !36 + %217 = xor i32 %202, %205, !dbg !38 + %218 = select i1 %216, i32 %217, i32 0, !dbg !39 + %219 = xor i32 %218, %196, !dbg !40 + %220 = xor i32 %208, %211, !dbg !41 + %221 = select i1 %216, i32 %220, i32 0, !dbg !42 + %222 = xor i32 %221, %199, !dbg !43 + %223 = mul nuw nsw i32 %219, %32, !dbg !24 + %224 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %223, i32 2, i32 31), !dbg !25 + %225 = add i32 %223, %224, !dbg !28 + %226 = mul nuw nsw i32 %219, %.lobit, !dbg !29 + %227 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %226, i32 2, i32 31), !dbg !25 + %228 = add i32 %226, %227, !dbg !28 + %229 = mul nuw nsw i32 %222, %32, !dbg !30 + %230 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %229, i32 2, i32 31), !dbg !25 + %231 = add i32 %229, %230, !dbg !28 + %232 = mul nuw nsw i32 %222, %.lobit, !dbg !31 + %233 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %232, i32 2, 
i32 31), !dbg !25 + %234 = add i32 %232, %233, !dbg !28 + %235 = icmp slt i32 %225, %228, !dbg !32 + %236 = icmp eq i32 %225, %228, !dbg !33 + %237 = icmp sgt i32 %231, %234, !dbg !34 + %238 = and i1 %236, %237, !dbg !35 + %239 = or i1 %235, %238, !dbg !36 + %240 = xor i32 %225, %228, !dbg !38 + %241 = select i1 %239, i32 %240, i32 0, !dbg !39 + %242 = xor i32 %241, %219, !dbg !40 + %243 = xor i32 %231, %234, !dbg !41 + %244 = select i1 %239, i32 %243, i32 0, !dbg !42 + %245 = xor i32 %244, %222, !dbg !43 + %246 = mul nuw nsw i32 %242, %31, !dbg !24 + %247 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %246, i32 1, i32 31), !dbg !25 + %248 = add i32 %246, %247, !dbg !28 + %249 = mul nuw nsw i32 %242, %28, !dbg !29 + %250 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %249, i32 1, i32 31), !dbg !25 + %251 = add i32 %249, %250, !dbg !28 + %252 = mul nuw nsw i32 %245, %31, !dbg !30 + %253 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %252, i32 1, i32 31), !dbg !25 + %254 = add i32 %252, %253, !dbg !28 + %255 = mul nuw nsw i32 %245, %28, !dbg !31 + %256 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %255, i32 1, i32 31), !dbg !25 + %257 = add i32 %255, %256, !dbg !28 + %258 = icmp slt i32 %248, %251, !dbg !32 + %259 = icmp eq i32 %248, %251, !dbg !33 + %260 = icmp sgt i32 %254, %257, !dbg !34 + %261 = and i1 %259, %260, !dbg !35 + %262 = or i1 %258, %261, !dbg !36 + %263 = xor i32 %254, %257, !dbg !41 + %264 = select i1 %262, i32 %263, i32 0, !dbg !42 + %265 = xor i32 %264, %245, !dbg !43 + %266 = icmp eq i64 %20, 16384, !dbg !44 + %267 = icmp eq i64 %21, 16384, !dbg !44 + %268 = zext i1 %266 to i32, !dbg !18 + %269 = select i1 %266, i32 %31, i32 0, !dbg !45 + %270 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %269, i32 1, i32 31), !dbg !47 + %271 = add i32 %270, %269, !dbg !48 + %272 = select i1 %266, i32 %28, i32 0, !dbg !49 + %273 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %272, i32 1, i32 31), 
!dbg !47 + %274 = add i32 %273, %272, !dbg !48 + %275 = icmp slt i32 %271, %274, !dbg !50 + %276 = icmp eq i32 %271, %274, !dbg !51 + %277 = and i1 %49, %276, !dbg !52 + %278 = or i1 %275, %277, !dbg !53 + %279 = xor i1 %278, %52, !dbg !54 + %280 = xor i32 %274, %271, !dbg !55 + %281 = select i1 %279, i32 %280, i32 0, !dbg !56 + %282 = xor i32 %281, %268, !dbg !57 + %283 = select i1 %279, i32 %57, i32 0, !dbg !58 + %284 = xor i32 %283, %15, !dbg !59 + %285 = mul nuw nsw i32 %282, %32, !dbg !45 + %286 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %285, i32 2, i32 31), !dbg !47 + %287 = add i32 %285, %286, !dbg !48 + %288 = mul nuw nsw i32 %282, %.lobit, !dbg !49 + %289 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %288, i32 2, i32 31), !dbg !47 + %290 = add i32 %288, %289, !dbg !48 + %291 = mul nuw nsw i32 %284, %32, !dbg !60 + %292 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %291, i32 2, i32 31), !dbg !47 + %293 = add i32 %291, %292, !dbg !48 + %294 = mul nuw nsw i32 %284, %.lobit, !dbg !61 + %295 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %294, i32 2, i32 31), !dbg !47 + %296 = add i32 %294, %295, !dbg !48 + %297 = icmp sge i32 %287, %290, !dbg !54 + %298 = icmp ne i32 %287, %290, !dbg !54 + %299 = icmp sle i32 %293, %296, !dbg !54 + %300 = or i1 %298, %299, !dbg !54 + %301 = and i1 %297, %300, !dbg !54 + %.not8 = xor i1 %301, %72, !dbg !54 + %302 = xor i32 %287, %290, !dbg !55 + %303 = select i1 %.not8, i32 0, i32 %302, !dbg !56 + %304 = xor i32 %303, %282, !dbg !57 + %305 = xor i32 %293, %296, !dbg !62 + %306 = select i1 %.not8, i32 0, i32 %305, !dbg !58 + %307 = xor i32 %306, %284, !dbg !59 + %308 = mul nuw nsw i32 %304, %31, !dbg !45 + %309 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %308, i32 1, i32 31), !dbg !47 + %310 = add i32 %308, %309, !dbg !48 + %311 = mul nuw nsw i32 %304, %28, !dbg !49 + %312 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %311, i32 1, i32 31), !dbg !47 + 
%313 = add i32 %311, %312, !dbg !48 + %314 = mul nuw nsw i32 %307, %31, !dbg !60 + %315 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %314, i32 1, i32 31), !dbg !47 + %316 = add i32 %314, %315, !dbg !48 + %317 = mul nuw nsw i32 %307, %28, !dbg !61 + %318 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %317, i32 1, i32 31), !dbg !47 + %319 = add i32 %317, %318, !dbg !48 + %320 = icmp sge i32 %310, %313, !dbg !54 + %321 = icmp ne i32 %310, %313, !dbg !54 + %322 = icmp sle i32 %316, %319, !dbg !54 + %323 = or i1 %321, %322, !dbg !54 + %324 = and i1 %320, %323, !dbg !54 + %.not9 = xor i1 %324, %72, !dbg !54 + %325 = xor i32 %310, %313, !dbg !55 + %326 = select i1 %.not9, i32 0, i32 %325, !dbg !56 + %327 = xor i32 %326, %304, !dbg !57 + %328 = xor i32 %316, %319, !dbg !62 + %329 = select i1 %.not9, i32 0, i32 %328, !dbg !58 + %330 = xor i32 %329, %307, !dbg !59 + %331 = mul nuw nsw i32 %327, %33, !dbg !45 + %332 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %331, i32 4, i32 31), !dbg !47 + %333 = add i32 %331, %332, !dbg !48 + %334 = mul nuw nsw i32 %327, %.lobit1, !dbg !49 + %335 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %334, i32 4, i32 31), !dbg !47 + %336 = add i32 %334, %335, !dbg !48 + %337 = mul nuw nsw i32 %330, %33, !dbg !60 + %338 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %337, i32 4, i32 31), !dbg !47 + %339 = add i32 %337, %338, !dbg !48 + %340 = mul nuw nsw i32 %330, %.lobit1, !dbg !61 + %341 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %340, i32 4, i32 31), !dbg !47 + %342 = add i32 %340, %341, !dbg !48 + %343 = icmp sge i32 %333, %336, !dbg !54 + %344 = icmp ne i32 %333, %336, !dbg !54 + %345 = icmp sle i32 %339, %342, !dbg !54 + %346 = or i1 %344, %345, !dbg !54 + %347 = and i1 %343, %346, !dbg !54 + %.not10 = xor i1 %347, %119, !dbg !54 + %348 = xor i32 %333, %336, !dbg !55 + %349 = select i1 %.not10, i32 0, i32 %348, !dbg !56 + %350 = xor i32 %349, %327, !dbg !57 + %351 = 
xor i32 %339, %342, !dbg !62 + %352 = select i1 %.not10, i32 0, i32 %351, !dbg !58 + %353 = xor i32 %352, %330, !dbg !59 + %354 = mul nuw nsw i32 %350, %32, !dbg !45 + %355 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %354, i32 2, i32 31), !dbg !47 + %356 = add i32 %354, %355, !dbg !48 + %357 = mul nuw nsw i32 %350, %.lobit, !dbg !49 + %358 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %357, i32 2, i32 31), !dbg !47 + %359 = add i32 %357, %358, !dbg !48 + %360 = mul nuw nsw i32 %353, %32, !dbg !60 + %361 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %360, i32 2, i32 31), !dbg !47 + %362 = add i32 %360, %361, !dbg !48 + %363 = mul nuw nsw i32 %353, %.lobit, !dbg !61 + %364 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %363, i32 2, i32 31), !dbg !47 + %365 = add i32 %363, %364, !dbg !48 + %366 = icmp sge i32 %356, %359, !dbg !54 + %367 = icmp ne i32 %356, %359, !dbg !54 + %368 = icmp sle i32 %362, %365, !dbg !54 + %369 = or i1 %367, %368, !dbg !54 + %370 = and i1 %366, %369, !dbg !54 + %.not11 = xor i1 %370, %119, !dbg !54 + %371 = xor i32 %356, %359, !dbg !55 + %372 = select i1 %.not11, i32 0, i32 %371, !dbg !56 + %373 = xor i32 %372, %350, !dbg !57 + %374 = xor i32 %362, %365, !dbg !62 + %375 = select i1 %.not11, i32 0, i32 %374, !dbg !58 + %376 = xor i32 %375, %353, !dbg !59 + %377 = mul nuw nsw i32 %373, %31, !dbg !45 + %378 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %377, i32 1, i32 31), !dbg !47 + %379 = add i32 %377, %378, !dbg !48 + %380 = mul nuw nsw i32 %373, %28, !dbg !49 + %381 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %380, i32 1, i32 31), !dbg !47 + %382 = add i32 %380, %381, !dbg !48 + %383 = mul nuw nsw i32 %376, %31, !dbg !60 + %384 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %383, i32 1, i32 31), !dbg !47 + %385 = add i32 %383, %384, !dbg !48 + %386 = mul nuw nsw i32 %376, %28, !dbg !61 + %387 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %386, 
i32 1, i32 31), !dbg !47 + %388 = add i32 %386, %387, !dbg !48 + %389 = icmp sge i32 %379, %382, !dbg !54 + %390 = icmp ne i32 %379, %382, !dbg !54 + %391 = icmp sle i32 %385, %388, !dbg !54 + %392 = or i1 %390, %391, !dbg !54 + %393 = and i1 %389, %392, !dbg !54 + %.not12 = xor i1 %393, %119, !dbg !54 + %394 = xor i32 %379, %382, !dbg !55 + %395 = select i1 %.not12, i32 0, i32 %394, !dbg !56 + %396 = xor i32 %395, %373, !dbg !57 + %397 = xor i32 %385, %388, !dbg !62 + %398 = select i1 %.not12, i32 0, i32 %397, !dbg !58 + %399 = xor i32 %398, %376, !dbg !59 + %400 = mul nuw nsw i32 %396, %34, !dbg !45 + %401 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %400, i32 8, i32 31), !dbg !47 + %402 = add i32 %400, %401, !dbg !48 + %403 = mul nuw nsw i32 %396, %.lobit2, !dbg !49 + %404 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %403, i32 8, i32 31), !dbg !47 + %405 = add i32 %403, %404, !dbg !48 + %406 = mul nuw nsw i32 %399, %34, !dbg !60 + %407 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %406, i32 8, i32 31), !dbg !47 + %408 = add i32 %406, %407, !dbg !48 + %409 = mul nuw nsw i32 %399, %.lobit2, !dbg !61 + %410 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %409, i32 8, i32 31), !dbg !47 + %411 = add i32 %409, %410, !dbg !48 + %412 = icmp slt i32 %402, %405, !dbg !50 + %413 = icmp eq i32 %402, %405, !dbg !51 + %414 = icmp sgt i32 %408, %411, !dbg !63 + %415 = and i1 %413, %414, !dbg !52 + %416 = or i1 %412, %415, !dbg !53 + %417 = xor i32 %402, %405, !dbg !55 + %418 = select i1 %416, i32 %417, i32 0, !dbg !56 + %419 = xor i32 %418, %396, !dbg !57 + %420 = xor i32 %408, %411, !dbg !62 + %421 = select i1 %416, i32 %420, i32 0, !dbg !58 + %422 = xor i32 %421, %399, !dbg !59 + %423 = mul nuw nsw i32 %419, %33, !dbg !45 + %424 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %423, i32 4, i32 31), !dbg !47 + %425 = add i32 %423, %424, !dbg !48 + %426 = mul nuw nsw i32 %419, %.lobit1, !dbg !49 + %427 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %426, i32 4, i32 31), !dbg !47 + %428 = add i32 %426, %427, !dbg !48 + %429 = mul nuw nsw i32 %422, %33, !dbg !60 + %430 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %429, i32 4, i32 31), !dbg !47 + %431 = add i32 %429, %430, !dbg !48 + %432 = mul nuw nsw i32 %422, %.lobit1, !dbg !61 + %433 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %432, i32 4, i32 31), !dbg !47 + %434 = add i32 %432, %433, !dbg !48 + %435 = icmp slt i32 %425, %428, !dbg !50 + %436 = icmp eq i32 %425, %428, !dbg !51 + %437 = icmp sgt i32 %431, %434, !dbg !63 + %438 = and i1 %436, %437, !dbg !52 + %439 = or i1 %435, %438, !dbg !53 + %440 = xor i32 %425, %428, !dbg !55 + %441 = select i1 %439, i32 %440, i32 0, !dbg !56 + %442 = xor i32 %441, %419, !dbg !57 + %443 = xor i32 %431, %434, !dbg !62 + %444 = select i1 %439, i32 %443, i32 0, !dbg !58 + %445 = xor i32 %444, %422, !dbg !59 + %446 = mul nuw nsw i32 %442, %32, !dbg !45 + %447 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %446, i32 2, i32 31), !dbg !47 + %448 = add i32 %446, %447, !dbg !48 + %449 = mul nuw nsw i32 %442, %.lobit, !dbg !49 + %450 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %449, i32 2, i32 31), !dbg !47 + %451 = add i32 %449, %450, !dbg !48 + %452 = mul nuw nsw i32 %445, %32, !dbg !60 + %453 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %452, i32 2, i32 31), !dbg !47 + %454 = add i32 %452, %453, !dbg !48 + %455 = mul nuw nsw i32 %445, %.lobit, !dbg !61 + %456 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %455, i32 2, i32 31), !dbg !47 + %457 = add i32 %455, %456, !dbg !48 + %458 = icmp slt i32 %448, %451, !dbg !50 + %459 = icmp eq i32 %448, %451, !dbg !51 + %460 = icmp sgt i32 %454, %457, !dbg !63 + %461 = and i1 %459, %460, !dbg !52 + %462 = or i1 %458, %461, !dbg !53 + %463 = xor i32 %448, %451, !dbg !55 + %464 = select i1 %462, i32 %463, i32 0, !dbg !56 + %465 = xor i32 %464, %442, !dbg !57 + %466 = xor i32 
%454, %457, !dbg !62 + %467 = select i1 %462, i32 %466, i32 0, !dbg !58 + %468 = xor i32 %467, %445, !dbg !59 + %469 = mul nuw nsw i32 %465, %31, !dbg !45 + %470 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %469, i32 1, i32 31), !dbg !47 + %471 = add i32 %469, %470, !dbg !48 + %472 = mul nuw nsw i32 %465, %28, !dbg !49 + %473 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %472, i32 1, i32 31), !dbg !47 + %474 = add i32 %472, %473, !dbg !48 + %475 = mul nuw nsw i32 %468, %31, !dbg !60 + %476 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %475, i32 1, i32 31), !dbg !47 + %477 = add i32 %475, %476, !dbg !48 + %478 = mul nuw nsw i32 %468, %28, !dbg !61 + %479 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %478, i32 1, i32 31), !dbg !47 + %480 = add i32 %478, %479, !dbg !48 + %481 = icmp slt i32 %471, %474, !dbg !50 + %482 = icmp eq i32 %471, %474, !dbg !51 + %483 = icmp sgt i32 %477, %480, !dbg !63 + %484 = and i1 %482, %483, !dbg !52 + %485 = or i1 %481, %484, !dbg !53 + %486 = xor i32 %477, %480, !dbg !62 + %487 = select i1 %485, i32 %486, i32 0, !dbg !58 + %488 = xor i32 %487, %468, !dbg !59 + %narrow = select i1 %13, i1 %23, i1 false, !dbg !64 + %489 = zext i1 %narrow to i64, !dbg !64 + %narrow13 = select i1 %13, i1 %25, i1 false, !dbg !64 + %490 = zext i1 %narrow13 to i64, !dbg !64 + %491 = zext i1 %narrow to i32, !dbg !65 + %492 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %491, i32 8, i32 31), !dbg !65 + %493 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 8, i32 31), !dbg !65 + %494 = insertelement <2 x i32> poison, i32 %492, i64 0, !dbg !65 + %495 = insertelement <2 x i32> %494, i32 %493, i64 1, !dbg !65 + %496 = bitcast <2 x i32> %495 to i64, !dbg !65 + %497 = add i64 %496, %489, !dbg !67 + %extelt.offset = lshr i64 %497, 32, !dbg !65 + %498 = trunc nuw i64 %extelt.offset to i32, !dbg !65 + %499 = trunc i64 %497 to i32, !dbg !65 + %500 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 
-1, i32 %499, i32 4, i32 31), !dbg !65 + %501 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %498, i32 4, i32 31), !dbg !65 + %502 = insertelement <2 x i32> poison, i32 %500, i64 0, !dbg !65 + %503 = insertelement <2 x i32> %502, i32 %501, i64 1, !dbg !65 + %504 = bitcast <2 x i32> %503 to i64, !dbg !65 + %505 = add i64 %497, %504, !dbg !67 + %extelt.offset14 = lshr i64 %505, 32, !dbg !65 + %506 = trunc nuw i64 %extelt.offset14 to i32, !dbg !65 + %507 = trunc i64 %505 to i32, !dbg !65 + %508 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %507, i32 2, i32 31), !dbg !65 + %509 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %506, i32 2, i32 31), !dbg !65 + %510 = insertelement <2 x i32> poison, i32 %508, i64 0, !dbg !65 + %511 = insertelement <2 x i32> %510, i32 %509, i64 1, !dbg !65 + %512 = bitcast <2 x i32> %511 to i64, !dbg !65 + %513 = add i64 %505, %512, !dbg !67 + %extelt.offset15 = lshr i64 %513, 32, !dbg !65 + %514 = trunc nuw i64 %extelt.offset15 to i32, !dbg !65 + %515 = trunc i64 %513 to i32, !dbg !65 + %516 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %515, i32 1, i32 31), !dbg !65 + %517 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %514, i32 1, i32 31), !dbg !65 + %518 = insertelement <2 x i32> poison, i32 %516, i64 0, !dbg !65 + %519 = insertelement <2 x i32> %518, i32 %517, i64 1, !dbg !65 + %520 = bitcast <2 x i32> %519 to i64, !dbg !65 + %521 = add i64 %513, %520, !dbg !67 + %522 = zext i1 %narrow13 to i32, !dbg !65 + %523 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %522, i32 8, i32 31), !dbg !65 + %524 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 8, i32 31), !dbg !65 + %525 = insertelement <2 x i32> poison, i32 %523, i64 0, !dbg !65 + %526 = insertelement <2 x i32> %525, i32 %524, i64 1, !dbg !65 + %527 = bitcast <2 x i32> %526 to i64, !dbg !65 + %528 = add i64 %527, %490, !dbg !67 + %extelt.offset17 = lshr i64 %528, 32, !dbg !65 + %529 = trunc nuw i64 
%extelt.offset17 to i32, !dbg !65 + %530 = trunc i64 %528 to i32, !dbg !65 + %531 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %530, i32 4, i32 31), !dbg !65 + %532 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %529, i32 4, i32 31), !dbg !65 + %533 = insertelement <2 x i32> poison, i32 %531, i64 0, !dbg !65 + %534 = insertelement <2 x i32> %533, i32 %532, i64 1, !dbg !65 + %535 = bitcast <2 x i32> %534 to i64, !dbg !65 + %536 = add i64 %528, %535, !dbg !67 + %extelt.offset18 = lshr i64 %536, 32, !dbg !65 + %537 = trunc nuw i64 %extelt.offset18 to i32, !dbg !65 + %538 = trunc i64 %536 to i32, !dbg !65 + %539 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %538, i32 2, i32 31), !dbg !65 + %540 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %537, i32 2, i32 31), !dbg !65 + %541 = insertelement <2 x i32> poison, i32 %539, i64 0, !dbg !65 + %542 = insertelement <2 x i32> %541, i32 %540, i64 1, !dbg !65 + %543 = bitcast <2 x i32> %542 to i64, !dbg !65 + %544 = add i64 %536, %543, !dbg !67 + %extelt.offset19 = lshr i64 %544, 32, !dbg !65 + %545 = trunc nuw i64 %extelt.offset19 to i32, !dbg !65 + %546 = trunc i64 %544 to i32, !dbg !65 + %547 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %546, i32 1, i32 31), !dbg !65 + %548 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %545, i32 1, i32 31), !dbg !65 + %narrow20 = select i1 %13, i1 %266, i1 false, !dbg !68 + %549 = zext i1 %narrow20 to i64, !dbg !68 + %narrow21 = select i1 %13, i1 %267, i1 false, !dbg !68 + %550 = zext i1 %narrow21 to i64, !dbg !68 + %551 = zext i1 %narrow20 to i32, !dbg !69 + %552 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %551, i32 8, i32 31), !dbg !69 + %553 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 8, i32 31), !dbg !69 + %554 = insertelement <2 x i32> poison, i32 %552, i64 0, !dbg !69 + %555 = insertelement <2 x i32> %554, i32 %553, i64 1, !dbg !69 + %556 = bitcast <2 x i32> %555 to i64, !dbg !69 
+ %557 = add i64 %556, %549, !dbg !71 + %extelt.offset23 = lshr i64 %557, 32, !dbg !69 + %558 = trunc nuw i64 %extelt.offset23 to i32, !dbg !69 + %559 = trunc i64 %557 to i32, !dbg !69 + %560 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %559, i32 4, i32 31), !dbg !69 + %561 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %558, i32 4, i32 31), !dbg !69 + %562 = insertelement <2 x i32> poison, i32 %560, i64 0, !dbg !69 + %563 = insertelement <2 x i32> %562, i32 %561, i64 1, !dbg !69 + %564 = bitcast <2 x i32> %563 to i64, !dbg !69 + %565 = add i64 %557, %564, !dbg !71 + %extelt.offset24 = lshr i64 %565, 32, !dbg !69 + %566 = trunc nuw i64 %extelt.offset24 to i32, !dbg !69 + %567 = trunc i64 %565 to i32, !dbg !69 + %568 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %567, i32 2, i32 31), !dbg !69 + %569 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %566, i32 2, i32 31), !dbg !69 + %570 = insertelement <2 x i32> poison, i32 %568, i64 0, !dbg !69 + %571 = insertelement <2 x i32> %570, i32 %569, i64 1, !dbg !69 + %572 = bitcast <2 x i32> %571 to i64, !dbg !69 + %573 = add i64 %565, %572, !dbg !71 + %extelt.offset25 = lshr i64 %573, 32, !dbg !69 + %574 = trunc nuw i64 %extelt.offset25 to i32, !dbg !69 + %575 = trunc i64 %573 to i32, !dbg !69 + %576 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %575, i32 1, i32 31), !dbg !69 + %577 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %574, i32 1, i32 31), !dbg !69 + %578 = zext i1 %narrow21 to i32, !dbg !69 + %579 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %578, i32 8, i32 31), !dbg !69 + %580 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 0, i32 8, i32 31), !dbg !69 + %581 = insertelement <2 x i32> poison, i32 %579, i64 0, !dbg !69 + %582 = insertelement <2 x i32> %581, i32 %580, i64 1, !dbg !69 + %583 = bitcast <2 x i32> %582 to i64, !dbg !69 + %584 = add i64 %583, %550, !dbg !71 + %extelt.offset27 = lshr i64 %584, 32, !dbg !69 + 
%585 = trunc nuw i64 %extelt.offset27 to i32, !dbg !69 + %586 = trunc i64 %584 to i32, !dbg !69 + %587 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %586, i32 4, i32 31), !dbg !69 + %588 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %585, i32 4, i32 31), !dbg !69 + %589 = insertelement <2 x i32> poison, i32 %587, i64 0, !dbg !69 + %590 = insertelement <2 x i32> %589, i32 %588, i64 1, !dbg !69 + %591 = bitcast <2 x i32> %590 to i64, !dbg !69 + %592 = add i64 %584, %591, !dbg !71 + %extelt.offset28 = lshr i64 %592, 32, !dbg !69 + %593 = trunc nuw i64 %extelt.offset28 to i32, !dbg !69 + %594 = trunc i64 %592 to i32, !dbg !69 + %595 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %594, i32 2, i32 31), !dbg !69 + %596 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %593, i32 2, i32 31), !dbg !69 + %597 = insertelement <2 x i32> poison, i32 %595, i64 0, !dbg !69 + %598 = insertelement <2 x i32> %597, i32 %596, i64 1, !dbg !69 + %599 = bitcast <2 x i32> %598 to i64, !dbg !69 + %600 = add i64 %592, %599, !dbg !71 + %extelt.offset29 = lshr i64 %600, 32, !dbg !69 + %601 = trunc nuw i64 %extelt.offset29 to i32, !dbg !69 + %602 = trunc i64 %600 to i32, !dbg !69 + %603 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %602, i32 1, i32 31), !dbg !69 + %604 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %601, i32 1, i32 31), !dbg !69 + %605 = trunc i64 %521 to i32, !dbg !72 + %606 = icmp slt i32 %15, %605, !dbg !73 + %607 = select i1 %606, i32 %265, i32 16, !dbg !74 + %608 = add i32 %607, 17, !dbg !75 + %609 = icmp slt i32 %607, 0, !dbg !76 + %610 = select i1 %609, i32 %608, i32 %607, !dbg !77 + %611 = icmp ult i32 %610, 17, !dbg !78 + %612 = icmp samesign ugt i32 %12, 31, !dbg !18 + %613 = or i1 %612, %611, !dbg !79 + br i1 %613, label %615, label %614, !dbg !80 + +614: ; preds = %11 + tail call void @__assertfail(ptr nonnull @assertMessage_0, ptr nonnull @assertFile_0, i32 71, ptr nonnull @assertFunc_0, i64 1), !dbg 
!80 + unreachable, !dbg !80 + +615: ; preds = %11 + %616 = insertelement <2 x i32> poison, i32 %576, i64 0, !dbg !69 + %617 = insertelement <2 x i32> %616, i32 %577, i64 1, !dbg !69 + %618 = bitcast <2 x i32> %617 to i64, !dbg !69 + %619 = add i64 %573, %618, !dbg !71 + %620 = trunc i64 %619 to i32, !dbg !81 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !80 + %621 = icmp slt i32 %15, %620, !dbg !82 + %622 = select i1 %621, i32 %488, i32 16, !dbg !83 + %623 = add i32 %622, 17, !dbg !84 + %624 = icmp slt i32 %622, 0, !dbg !85 + %625 = select i1 %624, i32 %623, i32 %622, !dbg !86 + %626 = icmp ult i32 %625, 17, !dbg !87 + %627 = or i1 %612, %626, !dbg !88 + br i1 %627, label %629, label %628, !dbg !89 + +628: ; preds = %615 + tail call void @__assertfail(ptr nonnull @assertMessage_1, ptr nonnull @assertFile_1, i32 80, ptr nonnull @assertFunc_1, i64 1), !dbg !89 + unreachable, !dbg !89 + +629: ; preds = %615 + %630 = insertelement <2 x i32> poison, i32 %603, i64 0, !dbg !69 + %631 = insertelement <2 x i32> %630, i32 %604, i64 1, !dbg !69 + %632 = bitcast <2 x i32> %631 to i64, !dbg !69 + %633 = add i64 %600, %632, !dbg !71 + %634 = trunc i64 %633 to i32, !dbg !81 + %635 = insertelement <2 x i32> poison, i32 %547, i64 0, !dbg !65 + %636 = insertelement <2 x i32> %635, i32 %548, i64 1, !dbg !65 + %637 = bitcast <2 x i32> %636 to i64, !dbg !65 + %638 = add i64 %544, %637, !dbg !67 + %639 = trunc i64 %638 to i32, !dbg !72 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !89 + %640 = zext nneg i32 %12 to i64, !dbg !90 + %641 = getelementptr i32, ptr addrspace(1) %1, i64 %640, !dbg !90 + %642 = and i32 %14, 63, !dbg !91 + %643 = icmp eq i32 %642, 0, !dbg !91 + %644 = and i1 %13, %643, !dbg !91 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %639, ptr addrspace(1) %641, i1 %644) #5, !dbg !91 + %645 = getelementptr i32, ptr addrspace(1) %2, i64 %640, !dbg !92 + tail call void asm sideeffect "@$2 
st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %634, ptr addrspace(1) %645, i1 %644) #5, !dbg !93 + %646 = getelementptr i32, ptr addrspace(1) %3, i64 %18, !dbg !94 + %647 = and i32 %14, 48, !dbg !95 + %648 = icmp eq i32 %647, 0, !dbg !95 + %649 = and i1 %13, %648, !dbg !95 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %265, ptr addrspace(1) %646, i1 %649) #5, !dbg !95 + %650 = mul i32 %12, 17, !dbg !96 + %651 = add i32 %610, %650, !dbg !97 + %652 = sext i32 %651 to i64, !dbg !98 + %653 = getelementptr i32, ptr addrspace(1) %4, i64 %652, !dbg !98 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %653, i1 %649) #5, !dbg !99 + %654 = getelementptr i32, ptr addrspace(1) %5, i64 %18, !dbg !100 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %488, ptr addrspace(1) %654, i1 %649) #5, !dbg !101 + %655 = add i32 %625, %650, !dbg !102 + %656 = sext i32 %655 to i64, !dbg !103 + %657 = getelementptr i32, ptr addrspace(1) %6, i64 %656, !dbg !103 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 1, ptr addrspace(1) %657, i1 %649) #5, !dbg !104 + ret void, !dbg !105 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #2 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #2 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #3 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #4 + +attributes #0 = { noreturn } +attributes #1 = { "nvvm.reqntid"="64" } +attributes #2 = { mustprogress nocallback nofree 
nosync nounwind speculatable willreturn memory(none) } +attributes #3 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #4 = { convergent nocallback nounwind } +attributes #5 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = !DISubprogram(name: "__assertfail", linkageName: "__assertfail", scope: !6, file: !6, type: !7, spFlags: DISPFlagOptimized) +!6 = !DIFile(filename: "", directory: "") +!7 = !DISubroutineType(cc: DW_CC_normal, types: !8) +!8 = !{} +!9 = distinct !DISubprogram(name: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", linkageName: "triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2", scope: !1, file: !1, line: 18, type: !7, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!10 = !DILocation(line: 24, column: 28, scope: !9) +!11 = !DILocation(line: 26, column: 21, scope: !9) +!12 = !DILocation(line: 27, column: 38, scope: !9) +!13 = !DILocation(line: 34, column: 40, scope: !9) +!14 = !DILocation(line: 34, column: 37, scope: !9) +!15 = !DILocation(line: 34, column: 30, scope: !9) +!16 = !DILocation(line: 34, column: 45, scope: !9) +!17 = !DILocation(line: 39, column: 18, scope: !9) +!18 = !DILocation(line: 0, scope: !9) +!19 = !DILocation(line: 627, column: 44, scope: !20, inlinedAt: !22) +!20 = distinct !DILexicalBlockFile(scope: !9, file: !21, discriminator: 0) 
+!21 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!22 = !DILocation(line: 46, column: 71, scope: !9) +!23 = !DILocation(line: 537, column: 21, scope: !20, inlinedAt: !22) +!24 = !DILocation(line: 538, column: 40, scope: !20, inlinedAt: !22) +!25 = !DILocation(line: 291, column: 36, scope: !26, inlinedAt: !22) +!26 = distinct !DILexicalBlockFile(scope: !9, file: !27, discriminator: 0) +!27 = !DIFile(filename: "standard.py", directory: "/workspace/specforge/lib/python3.11/site-packages/triton/language") +!28 = !DILocation(line: 261, column: 15, scope: !26, inlinedAt: !22) +!29 = !DILocation(line: 539, column: 41, scope: !20, inlinedAt: !22) +!30 = !DILocation(line: 548, column: 23, scope: !20, inlinedAt: !22) +!31 = !DILocation(line: 551, column: 23, scope: !20, inlinedAt: !22) +!32 = !DILocation(line: 574, column: 22, scope: !20, inlinedAt: !22) +!33 = !DILocation(line: 591, column: 21, scope: !20, inlinedAt: !22) +!34 = !DILocation(line: 594, column: 40, scope: !20, inlinedAt: !22) +!35 = !DILocation(line: 594, column: 29, scope: !20, inlinedAt: !22) +!36 = !DILocation(line: 594, column: 23, scope: !20, inlinedAt: !22) +!37 = !DILocation(line: 599, column: 28, scope: !20, inlinedAt: !22) +!38 = !DILocation(line: 600, column: 38, scope: !20, inlinedAt: !22) +!39 = !DILocation(line: 600, column: 46, scope: !20, inlinedAt: !22) +!40 = !DILocation(line: 600, column: 15, scope: !20, inlinedAt: !22) +!41 = !DILocation(line: 601, column: 48, scope: !20, inlinedAt: !22) +!42 = !DILocation(line: 601, column: 59, scope: !20, inlinedAt: !22) +!43 = !DILocation(line: 601, column: 22, scope: !20, inlinedAt: !22) +!44 = !DILocation(line: 47, column: 20, scope: !9) +!45 = !DILocation(line: 538, column: 40, scope: !20, inlinedAt: !46) +!46 = !DILocation(line: 51, column: 71, scope: !9) +!47 = !DILocation(line: 291, column: 36, scope: !26, inlinedAt: !46) +!48 = !DILocation(line: 261, column: 15, 
scope: !26, inlinedAt: !46) +!49 = !DILocation(line: 539, column: 41, scope: !20, inlinedAt: !46) +!50 = !DILocation(line: 574, column: 22, scope: !20, inlinedAt: !46) +!51 = !DILocation(line: 591, column: 21, scope: !20, inlinedAt: !46) +!52 = !DILocation(line: 594, column: 29, scope: !20, inlinedAt: !46) +!53 = !DILocation(line: 594, column: 23, scope: !20, inlinedAt: !46) +!54 = !DILocation(line: 599, column: 28, scope: !20, inlinedAt: !46) +!55 = !DILocation(line: 600, column: 38, scope: !20, inlinedAt: !46) +!56 = !DILocation(line: 600, column: 46, scope: !20, inlinedAt: !46) +!57 = !DILocation(line: 600, column: 15, scope: !20, inlinedAt: !46) +!58 = !DILocation(line: 601, column: 59, scope: !20, inlinedAt: !46) +!59 = !DILocation(line: 601, column: 22, scope: !20, inlinedAt: !46) +!60 = !DILocation(line: 548, column: 23, scope: !20, inlinedAt: !46) +!61 = !DILocation(line: 551, column: 23, scope: !20, inlinedAt: !46) +!62 = !DILocation(line: 601, column: 48, scope: !20, inlinedAt: !46) +!63 = !DILocation(line: 594, column: 40, scope: !20, inlinedAt: !46) +!64 = !DILocation(line: 54, column: 35, scope: !9) +!65 = !DILocation(line: 291, column: 36, scope: !26, inlinedAt: !66) +!66 = !DILocation(line: 55, column: 26, scope: !9) +!67 = !DILocation(line: 261, column: 15, scope: !26, inlinedAt: !66) +!68 = !DILocation(line: 58, column: 35, scope: !9) +!69 = !DILocation(line: 291, column: 36, scope: !26, inlinedAt: !70) +!70 = !DILocation(line: 59, column: 26, scope: !9) +!71 = !DILocation(line: 261, column: 15, scope: !26, inlinedAt: !70) +!72 = !DILocation(line: 60, column: 21, scope: !9) +!73 = !DILocation(line: 64, column: 19, scope: !9) +!74 = !DILocation(line: 66, column: 35, scope: !9) +!75 = !DILocation(line: 68, column: 20, scope: !9) +!76 = !DILocation(line: 69, column: 20, scope: !9) +!77 = !DILocation(line: 70, column: 35, scope: !9) +!78 = !DILocation(line: 71, column: 38, scope: !9) +!79 = !DILocation(line: 71, column: 53, scope: !9) +!80 = 
!DILocation(line: 71, column: 63, scope: !9) +!81 = !DILocation(line: 61, column: 21, scope: !9) +!82 = !DILocation(line: 75, column: 19, scope: !9) +!83 = !DILocation(line: 76, column: 35, scope: !9) +!84 = !DILocation(line: 77, column: 20, scope: !9) +!85 = !DILocation(line: 78, column: 20, scope: !9) +!86 = !DILocation(line: 79, column: 35, scope: !9) +!87 = !DILocation(line: 80, column: 38, scope: !9) +!88 = !DILocation(line: 80, column: 53, scope: !9) +!89 = !DILocation(line: 80, column: 63, scope: !9) +!90 = !DILocation(line: 81, column: 25, scope: !9) +!91 = !DILocation(line: 81, column: 37, scope: !9) +!92 = !DILocation(line: 82, column: 25, scope: !9) +!93 = !DILocation(line: 82, column: 37, scope: !9) +!94 = !DILocation(line: 83, column: 25, scope: !9) +!95 = !DILocation(line: 83, column: 47, scope: !9) +!96 = !DILocation(line: 84, column: 52, scope: !9) +!97 = !DILocation(line: 84, column: 49, scope: !9) +!98 = !DILocation(line: 84, column: 25, scope: !9) +!99 = !DILocation(line: 84, column: 85, scope: !9) +!100 = !DILocation(line: 85, column: 25, scope: !9) +!101 = !DILocation(line: 85, column: 47, scope: !9) +!102 = !DILocation(line: 86, column: 49, scope: !9) +!103 = !DILocation(line: 86, column: 25, scope: !9) +!104 = !DILocation(line: 86, column: 85, scope: !9) +!105 = !DILocation(line: 86, column: 4, scope: !9) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx new file mode 100644 index 0000000000000000000000000000000000000000..ec7f6f809950e068ed6a5e96a5e772b24e6a2c9d --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ptx @@ -0,0 +1,1673 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 // -- Begin function triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +) +.noreturn; +.global .align 1 .b8 assertFunc_1[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_1[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 98, 115, 47, 99, 98, 115, 54, 53, 50, 105, 55, 99, 116, 53, 55, 117, 103, 120, 54, 118, 110, 99, 51, 53, 110, 54, 51, 116, 105, 110, 112, 122, 117, 54, 98, 97, 51, 122, 121, 109, 117, 102, 111, 100, 105, 103, 52, 104, 112, 122, 97, 114, 119, 99, 108, 46, 112, 121}; +.global .align 1 .b8 assertMessage_1[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 52, 57, 32, 60, 32, 49, 55}; +.global .align 1 .b8 assertFunc_0[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_0[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 
112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 98, 115, 47, 99, 98, 115, 54, 53, 50, 105, 55, 99, 116, 53, 55, 117, 103, 120, 54, 118, 110, 99, 51, 53, 110, 54, 51, 116, 105, 110, 112, 122, 117, 54, 98, 97, 51, 122, 121, 109, 117, 102, 111, 100, 105, 103, 52, 104, 112, 122, 97, 114, 119, 99, 108, 46, 112, 121}; +.global .align 1 .b8 assertMessage_0[37] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 52, 48, 32, 60, 32, 49, 55}; + // @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 +.visible .entry triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2( + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6, + .param .u32 
triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_7, + .param .u32 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_8, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_9, + .param .u64 .ptr .global .align 1 triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_10 +) +.reqntid 64 +{ + .reg .pred %p<140>; + .reg .b32 %r<453>; + .reg .b64 %rd<107>; + .loc 1 18 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:18:0 + +// %bb.0: + ld.param.b64 %rd15, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_0]; +$L__tmp0: + .loc 1 24 28 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:24:28 + mov.u32 %r1, %ctaid.x; + .loc 1 26 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:26:21 + setp.lt.u32 %p2, %r1, 32; + .loc 1 27 38 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:27:38 + mov.u32 %r2, %tid.x; + and.b32 %r3, %r2, 15; + .loc 1 34 40 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:40 + shl.b32 %r14, %r1, 4; + .loc 1 34 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:37 + or.b32 %r15, %r3, %r14; + .loc 1 34 30 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:30 + mad.wide.s32 %rd12, %r15, 8, %rd15; + .loc 1 34 45 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:34:45 + // begin inline asm + mov.u64 %rd11, 0x0; + @%p2 ld.global.b64 { %rd11 }, [ %rd12 + 0 ]; + // end inline asm + // begin inline asm + mov.u64 %rd13, 0x0; + @%p2 ld.global.b64 { %rd13 }, [ %rd12 + 0 ]; + // end 
inline asm + .loc 1 39 18 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:39:18 + add.s64 %rd16, %rd11, -1; + setp.lt.u64 %p3, %rd16, 16383; + add.s64 %rd17, %rd13, -1; + setp.lt.u64 %p4, %rd17, 16383; + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + selp.b32 %r16, 1, 0, %p3; +$L__tmp1: + .loc 2 627 44 // triton_helpers.py:627:44 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shr.u32 %r17, %r2, 1; + bfe.u32 %r18, %r2, 1, 1; + and.b32 %r19, %r2, 1; + shr.u32 %r20, %r2, 2; + bfe.u32 %r21, %r2, 2, 1; + shr.u32 %r22, %r2, 3; + bfe.u32 %r23, %r2, 3, 1; + .loc 2 537 21 // triton_helpers.py:537:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r24, %r19, 1; + xor.b32 %r25, %r18, 1; + xor.b32 %r26, %r21, 1; + xor.b32 %r27, %r23, 1; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r28, %r24, 0, %p3; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r29, %r28, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r30, %r28, %r29; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r31, %r19, 0, %p3; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r32, %r31, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r33, %r31, %r32; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r34, %r24, %r3; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r35, %r34, 1, 31, -1; + .loc 3 261 15 // 
standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r36, %r35, %r34; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r37, %r3, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r38, %r37, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r39, %r38, %r37; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p5, %r30, %r33; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p6, %r30, %r33; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p7, %r36, %r39; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p8, %p6, %p7; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p9, %p5, %p8; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r40, %r17, 1; + setp.ne.b32 %p10, %r40, 0; + xor.pred %p11, %p9, %p10; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r41, %r30, %r33; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r42, %r41, 0, %p11; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r43, %r42, %r16; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r44, %r39, %r36; + .loc 2 
601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r45, %r44, 0, %p11; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r46, %r45, %r3; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r47, %r43, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r48, %r47, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r49, %r47, %r48; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r50, %r43, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r51, %r50, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r52, %r50, %r51; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r53, %r46, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r54, %r53, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r55, %r53, %r54; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r56, %r46, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r57, %r56, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r58, %r56, %r57; + .loc 2 599 28 // 
triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r59, %r20, 1; + setp.ne.b32 %p12, %r59, 0; + setp.ge.s32 %p13, %r49, %r52; + setp.ne.b32 %p14, %r49, %r52; + setp.le.s32 %p15, %r55, %r58; + or.pred %p16, %p14, %p15; + and.pred %p17, %p13, %p16; + xor.pred %p18, %p17, %p12; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r60, %r49, %r52; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r61, 0, %r60, %p18; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r62, %r61, %r43; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r63, %r55, %r58; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r64, 0, %r63, %p18; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r65, %r64, %r46; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r66, %r62, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r67, %r66, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r68, %r66, %r67; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r69, %r62, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r70, %r69, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 
%r71, %r69, %r70; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r72, %r65, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r73, %r72, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r74, %r72, %r73; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r75, %r65, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r76, %r75, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r77, %r75, %r76; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p19, %r68, %r71; + setp.ne.b32 %p20, %r68, %r71; + setp.le.s32 %p21, %r74, %r77; + or.pred %p22, %p20, %p21; + and.pred %p23, %p19, %p22; + xor.pred %p24, %p23, %p12; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r78, %r68, %r71; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r79, 0, %r78, %p24; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r80, %r79, %r62; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r81, %r74, %r77; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r82, 0, %r81, %p24; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r83, %r82, 
%r65; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r84, %r80, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r85, %r84, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r86, %r84, %r85; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r87, %r80, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r88, %r87, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r89, %r87, %r88; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r90, %r83, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r91, %r90, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r92, %r90, %r91; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r93, %r83, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r94, %r93, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r95, %r93, %r94; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.b32 %r96, %r22, 1; + setp.ne.b32 %p25, %r96, 0; + setp.ge.s32 %p26, %r86, %r89; + setp.ne.b32 %p27, %r86, %r89; + setp.le.s32 %p28, %r92, %r95; + or.pred %p29, %p27, %p28; + 
and.pred %p30, %p26, %p29; + xor.pred %p31, %p30, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r97, %r86, %r89; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r98, 0, %r97, %p31; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r99, %r98, %r80; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r100, %r92, %r95; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r101, 0, %r100, %p31; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r102, %r101, %r83; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r103, %r99, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r104, %r103, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r105, %r103, %r104; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r106, %r99, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r107, %r106, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r108, %r106, %r107; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r109, %r102, %r25; + .loc 3 291 36 // standard.py:291:36 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r110, %r109, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r111, %r109, %r110; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r112, %r102, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r113, %r112, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r114, %r112, %r113; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p32, %r105, %r108; + setp.ne.b32 %p33, %r105, %r108; + setp.le.s32 %p34, %r111, %r114; + or.pred %p35, %p33, %p34; + and.pred %p36, %p32, %p35; + xor.pred %p37, %p36, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r115, %r105, %r108; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r116, 0, %r115, %p37; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r117, %r116, %r99; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r118, %r111, %r114; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r119, 0, %r118, %p37; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r120, %r119, %r102; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r121, %r117, %r24; + .loc 3 291 36 // 
standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r122, %r121, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r123, %r121, %r122; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r124, %r117, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r125, %r124, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r126, %r124, %r125; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r127, %r120, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r128, %r127, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r129, %r127, %r128; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r130, %r120, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r131, %r130, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r132, %r130, %r131; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.ge.s32 %p38, %r123, %r126; + setp.ne.b32 %p39, %r123, %r126; + setp.le.s32 %p40, %r129, %r132; + or.pred %p41, %p39, %p40; + and.pred %p42, %p38, %p41; + xor.pred %p43, %p42, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r133, %r123, 
%r126; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r134, 0, %r133, %p43; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r135, %r134, %r117; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r136, %r129, %r132; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r137, 0, %r136, %p43; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r138, %r137, %r120; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r139, %r135, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r140, %r139, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r141, %r139, %r140; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r142, %r135, %r23; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r143, %r142, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r144, %r142, %r143; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r145, %r138, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r146, %r145, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 
%r147, %r145, %r146; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r148, %r138, %r23; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r149, %r148, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r150, %r148, %r149; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p44, %r141, %r144; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p45, %r141, %r144; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p46, %r147, %r150; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p47, %p45, %p46; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p48, %p44, %p47; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r151, %r141, %r144; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r152, %r151, 0, %p48; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r153, %r152, %r135; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r154, %r147, %r150; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r155, %r154, 0, %p48; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + 
xor.b32 %r156, %r155, %r138; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r157, %r153, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r158, %r157, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r159, %r157, %r158; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r160, %r153, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r161, %r160, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r162, %r160, %r161; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r163, %r156, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r164, %r163, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r165, %r163, %r164; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r166, %r156, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r167, %r166, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r168, %r166, %r167; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p49, %r159, %r162; + .loc 2 591 21 // triton_helpers.py:591:21 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p50, %r159, %r162; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p51, %r165, %r168; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p52, %p50, %p51; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p53, %p49, %p52; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r169, %r159, %r162; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r170, %r169, 0, %p53; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r171, %r170, %r153; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r172, %r165, %r168; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r173, %r172, 0, %p53; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r174, %r173, %r156; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r175, %r171, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r176, %r175, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r177, %r175, %r176; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r178, %r171, %r18; + .loc 3 291 36 // standard.py:291:36 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r179, %r178, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r180, %r178, %r179; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r181, %r174, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r182, %r181, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r183, %r181, %r182; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r184, %r174, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r185, %r184, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r186, %r184, %r185; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p54, %r177, %r180; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p55, %r177, %r180; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p56, %r183, %r186; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + and.pred %p57, %p55, %p56; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + or.pred %p58, %p54, %p57; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r187, %r177, %r180; + .loc 2 600 46 // 
triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r188, %r187, 0, %p58; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r189, %r188, %r171; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r190, %r183, %r186; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r191, %r190, 0, %p58; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r192, %r191, %r174; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r193, %r189, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r194, %r193, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r195, %r193, %r194; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r196, %r189, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r197, %r196, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r198, %r196, %r197; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r199, %r192, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r200, %r199, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r201, %r199, %r200; + .loc 2 
551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + mul.lo.s32 %r202, %r192, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + shfl.sync.bfly.b32 %r203, %r202, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + add.s32 %r204, %r202, %r203; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.lt.s32 %p59, %r195, %r198; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.eq.b32 %p60, %r195, %r198; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + setp.gt.s32 %p61, %r201, %r204; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r205, %r201, %r204; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + selp.b32 %r206, %r205, 0, %p61; + selp.b32 %r207, %r206, 0, %p60; + selp.b32 %r208, %r205, %r207, %p59; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:46:71 ] + xor.b32 %r4, %r208, %r192; +$L__tmp2: + .loc 1 47 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:47:20 + setp.eq.b64 %p62, %rd11, 16384; + setp.eq.b64 %p63, %rd13, 16384; + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + selp.b32 %r209, 1, 0, %p62; +$L__tmp3: + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r210, %r24, 0, %p62; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r211, %r210, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r212, %r211, %r210; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r213, %r19, 0, %p62; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r214, %r213, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r215, %r214, %r213; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p64, %r212, %r215; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p65, %r212, %r215; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p66, %p7, %p65; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p67, %p64, %p66; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.pred %p68, %p67, %p10; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r216, %r215, %r212; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r217, %r216, 0, %p68; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r218, %r217, %r209; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r219, %r44, 0, %p68; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r220, %r219, %r3; + .loc 2 538 40 // triton_helpers.py:538:40 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r221, %r218, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r222, %r221, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r223, %r221, %r222; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r224, %r218, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r225, %r224, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r226, %r224, %r225; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r227, %r220, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r228, %r227, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r229, %r227, %r228; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r230, %r220, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r231, %r230, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r232, %r230, %r231; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p69, %r223, %r226; + setp.ne.b32 %p70, %r223, %r226; + setp.le.s32 %p71, %r229, %r232; + or.pred %p72, %p70, %p71; + and.pred %p73, %p69, %p72; + xor.pred %p74, %p73, %p12; + .loc 2 600 38 // 
triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r233, %r223, %r226; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r234, 0, %r233, %p74; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r235, %r234, %r218; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r236, %r229, %r232; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r237, 0, %r236, %p74; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r238, %r237, %r220; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r239, %r235, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r240, %r239, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r241, %r239, %r240; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r242, %r235, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r243, %r242, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r244, %r242, %r243; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r245, %r238, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r246, %r245, 1, 31, -1; + 
.loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r247, %r245, %r246; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r248, %r238, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r249, %r248, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r250, %r248, %r249; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p75, %r241, %r244; + setp.ne.b32 %p76, %r241, %r244; + setp.le.s32 %p77, %r247, %r250; + or.pred %p78, %p76, %p77; + and.pred %p79, %p75, %p78; + xor.pred %p80, %p79, %p12; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r251, %r241, %r244; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r252, 0, %r251, %p80; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r253, %r252, %r235; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r254, %r247, %r250; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r255, 0, %r254, %p80; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r256, %r255, %r238; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r257, %r253, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r258, 
%r257, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r259, %r257, %r258; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r260, %r253, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r261, %r260, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r262, %r260, %r261; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r263, %r256, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r264, %r263, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r265, %r263, %r264; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r266, %r256, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r267, %r266, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r268, %r266, %r267; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p81, %r259, %r262; + setp.ne.b32 %p82, %r259, %r262; + setp.le.s32 %p83, %r265, %r268; + or.pred %p84, %p82, %p83; + and.pred %p85, %p81, %p84; + xor.pred %p86, %p85, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r269, %r259, %r262; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] 
+ selp.b32 %r270, 0, %r269, %p86; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r271, %r270, %r253; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r272, %r265, %r268; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r273, 0, %r272, %p86; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r274, %r273, %r256; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r275, %r271, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r276, %r275, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r277, %r275, %r276; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r278, %r271, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r279, %r278, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r280, %r278, %r279; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r281, %r274, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r282, %r281, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r283, %r281, %r282; + .loc 2 551 23 // triton_helpers.py:551:23 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r284, %r274, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r285, %r284, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r286, %r284, %r285; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p87, %r277, %r280; + setp.ne.b32 %p88, %r277, %r280; + setp.le.s32 %p89, %r283, %r286; + or.pred %p90, %p88, %p89; + and.pred %p91, %p87, %p90; + xor.pred %p92, %p91, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r287, %r277, %r280; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r288, 0, %r287, %p92; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r289, %r288, %r271; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r290, %r283, %r286; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r291, 0, %r290, %p92; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r292, %r291, %r274; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r293, %r289, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r294, %r293, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r295, %r293, %r294; + .loc 2 539 41 // 
triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r296, %r289, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r297, %r296, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r298, %r296, %r297; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r299, %r292, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r300, %r299, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r301, %r299, %r300; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r302, %r292, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r303, %r302, 1, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r304, %r302, %r303; + .loc 2 599 28 // triton_helpers.py:599:28 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.ge.s32 %p93, %r295, %r298; + setp.ne.b32 %p94, %r295, %r298; + setp.le.s32 %p95, %r301, %r304; + or.pred %p96, %p94, %p95; + and.pred %p97, %p93, %p96; + xor.pred %p98, %p97, %p25; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r305, %r295, %r298; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r306, 0, %r305, %p98; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r307, %r306, 
%r289; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r308, %r301, %r304; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r309, 0, %r308, %p98; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r310, %r309, %r292; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r311, %r307, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r312, %r311, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r313, %r311, %r312; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r314, %r307, %r23; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r315, %r314, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r316, %r314, %r315; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r317, %r310, %r27; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r318, %r317, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r319, %r317, %r318; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r320, %r310, %r23; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 
%r321, %r320, 8, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r322, %r320, %r321; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p99, %r313, %r316; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p100, %r313, %r316; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.gt.s32 %p101, %r319, %r322; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p102, %p100, %p101; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p103, %p99, %p102; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r323, %r313, %r316; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r324, %r323, 0, %p103; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r325, %r324, %r307; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r326, %r319, %r322; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r327, %r326, 0, %p103; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r328, %r327, %r310; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r329, %r325, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + 
shfl.sync.bfly.b32 %r330, %r329, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r331, %r329, %r330; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r332, %r325, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r333, %r332, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r334, %r332, %r333; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r335, %r328, %r26; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r336, %r335, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r337, %r335, %r336; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r338, %r328, %r21; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r339, %r338, 4, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r340, %r338, %r339; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p104, %r331, %r334; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p105, %r331, %r334; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.gt.s32 %p106, %r337, %r340; + .loc 2 594 29 // triton_helpers.py:594:29 @[ 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p107, %p105, %p106; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p108, %p104, %p107; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r341, %r331, %r334; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r342, %r341, 0, %p108; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r343, %r342, %r325; + .loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r344, %r337, %r340; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r345, %r344, 0, %p108; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r346, %r345, %r328; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r347, %r343, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r348, %r347, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r349, %r347, %r348; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r350, %r343, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r351, %r350, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r352, %r350, %r351; + .loc 2 548 23 // 
triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r353, %r346, %r25; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r354, %r353, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r355, %r353, %r354; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r356, %r346, %r18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r357, %r356, 2, 31, -1; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + add.s32 %r358, %r356, %r357; + .loc 2 574 22 // triton_helpers.py:574:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.lt.s32 %p109, %r349, %r352; + .loc 2 591 21 // triton_helpers.py:591:21 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.eq.b32 %p110, %r349, %r352; + .loc 2 594 40 // triton_helpers.py:594:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + setp.gt.s32 %p111, %r355, %r358; + .loc 2 594 29 // triton_helpers.py:594:29 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + and.pred %p112, %p110, %p111; + .loc 2 594 23 // triton_helpers.py:594:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + or.pred %p113, %p109, %p112; + .loc 2 600 38 // triton_helpers.py:600:38 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r359, %r349, %r352; + .loc 2 600 46 // triton_helpers.py:600:46 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r360, %r359, 0, %p113; + .loc 2 600 15 // triton_helpers.py:600:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r361, %r360, %r343; + 
.loc 2 601 48 // triton_helpers.py:601:48 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r362, %r355, %r358; + .loc 2 601 59 // triton_helpers.py:601:59 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + selp.b32 %r363, %r362, 0, %p113; + .loc 2 601 22 // triton_helpers.py:601:22 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + xor.b32 %r364, %r363, %r346; + .loc 2 538 40 // triton_helpers.py:538:40 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r365, %r361, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r366, %r365, 1, 31, -1; + .loc 2 539 41 // triton_helpers.py:539:41 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r368, %r361, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r369, %r368, 1, 31, -1; + .loc 2 548 23 // triton_helpers.py:548:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r371, %r364, %r24; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r372, %r371, 1, 31, -1; + .loc 2 551 23 // triton_helpers.py:551:23 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + mul.lo.s32 %r374, %r364, %r19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:51:71 ] + shfl.sync.bfly.b32 %r375, %r374, 1, 31, -1; +$L__tmp4: + .loc 1 54 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:54:35 + and.pred %p117, %p2, %p3; + selp.b64 %rd18, 1, 0, %p117; + and.pred %p118, %p2, %p4; + selp.b64 %rd19, 1, 0, %p118; +$L__tmp5: + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + selp.b32 %r381, 1, 0, %p117; + shfl.sync.bfly.b32 %r382, %r381, 
8, 31, -1; + mov.b32 %r383, 0; + shfl.sync.bfly.b32 %r384, %r383, 8, 31, -1; + cvt.u64.u32 %rd20, %r382; + cvt.u64.u32 %rd21, %r384; + shl.b64 %rd22, %rd21, 32; + or.b64 %rd23, %rd20, %rd22; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd24, %rd23, %rd18; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r385}, %rd24; + cvt.u32.u64 %r386, %rd24; + shfl.sync.bfly.b32 %r387, %r386, 4, 31, -1; + shfl.sync.bfly.b32 %r388, %r385, 4, 31, -1; + cvt.u64.u32 %rd25, %r387; + cvt.u64.u32 %rd26, %r388; + shl.b64 %rd27, %rd26, 32; + or.b64 %rd28, %rd25, %rd27; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd29, %rd24, %rd28; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r389}, %rd29; + cvt.u32.u64 %r390, %rd29; + shfl.sync.bfly.b32 %r391, %r390, 2, 31, -1; + shfl.sync.bfly.b32 %r392, %r389, 2, 31, -1; + cvt.u64.u32 %rd30, %r391; + cvt.u64.u32 %rd31, %r392; + shl.b64 %rd32, %rd31, 32; + or.b64 %rd33, %rd30, %rd32; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd34, %rd29, %rd33; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r393}, %rd34; + cvt.u32.u64 %r394, %rd34; + shfl.sync.bfly.b32 %r395, %r394, 1, 31, -1; + shfl.sync.bfly.b32 %r396, %r393, 1, 31, -1; + cvt.u64.u32 %rd35, %r395; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd36, %rd34, %rd35; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + selp.b32 %r397, 1, 0, %p118; + shfl.sync.bfly.b32 %r398, %r397, 8, 31, -1; + shfl.sync.bfly.b32 %r399, %r383, 8, 31, -1; + cvt.u64.u32 
%rd37, %r398; + cvt.u64.u32 %rd38, %r399; + shl.b64 %rd39, %rd38, 32; + or.b64 %rd40, %rd37, %rd39; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd41, %rd40, %rd19; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r400}, %rd41; + cvt.u32.u64 %r401, %rd41; + shfl.sync.bfly.b32 %r402, %r401, 4, 31, -1; + shfl.sync.bfly.b32 %r403, %r400, 4, 31, -1; + cvt.u64.u32 %rd42, %r402; + cvt.u64.u32 %rd43, %r403; + shl.b64 %rd44, %rd43, 32; + or.b64 %rd45, %rd42, %rd44; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd46, %rd41, %rd45; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r404}, %rd46; + cvt.u32.u64 %r405, %rd46; + shfl.sync.bfly.b32 %r406, %r405, 2, 31, -1; + shfl.sync.bfly.b32 %r407, %r404, 2, 31, -1; + cvt.u64.u32 %rd47, %r406; + cvt.u64.u32 %rd48, %r407; + shl.b64 %rd49, %rd48, 32; + or.b64 %rd50, %rd47, %rd49; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd2, %rd46, %rd50; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + mov.b64 {_, %r408}, %rd2; + cvt.u32.u64 %r409, %rd2; + shfl.sync.bfly.b32 %r6, %r409, 1, 31, -1; + shfl.sync.bfly.b32 %r7, %r408, 1, 31, -1; +$L__tmp6: + .loc 1 58 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:58:35 + and.pred %p119, %p2, %p62; + selp.b64 %rd51, 1, 0, %p119; + and.pred %p120, %p2, %p63; + selp.b64 %rd52, 1, 0, %p120; +$L__tmp7: + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + selp.b32 %r410, 1, 0, %p119; + shfl.sync.bfly.b32 %r411, %r410, 8, 31, -1; + shfl.sync.bfly.b32 %r412, %r383, 8, 31, -1; + cvt.u64.u32 %rd53, %r411; + cvt.u64.u32 %rd54, %r412; + 
shl.b64 %rd55, %rd54, 32; + or.b64 %rd56, %rd53, %rd55; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd57, %rd56, %rd51; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r413}, %rd57; + cvt.u32.u64 %r414, %rd57; + shfl.sync.bfly.b32 %r415, %r414, 4, 31, -1; + shfl.sync.bfly.b32 %r416, %r413, 4, 31, -1; + cvt.u64.u32 %rd58, %r415; + cvt.u64.u32 %rd59, %r416; + shl.b64 %rd60, %rd59, 32; + or.b64 %rd61, %rd58, %rd60; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd62, %rd57, %rd61; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r417}, %rd62; + cvt.u32.u64 %r418, %rd62; + shfl.sync.bfly.b32 %r419, %r418, 2, 31, -1; + shfl.sync.bfly.b32 %r420, %r417, 2, 31, -1; + cvt.u64.u32 %rd63, %r419; + cvt.u64.u32 %rd64, %r420; + shl.b64 %rd65, %rd64, 32; + or.b64 %rd66, %rd63, %rd65; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd3, %rd62, %rd66; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r421}, %rd3; + cvt.u32.u64 %r422, %rd3; + shfl.sync.bfly.b32 %r8, %r422, 1, 31, -1; + shfl.sync.bfly.b32 %r9, %r421, 1, 31, -1; + selp.b32 %r423, 1, 0, %p120; + shfl.sync.bfly.b32 %r424, %r423, 8, 31, -1; + shfl.sync.bfly.b32 %r425, %r383, 8, 31, -1; + cvt.u64.u32 %rd67, %r424; + cvt.u64.u32 %rd68, %r425; + shl.b64 %rd69, %rd68, 32; + or.b64 %rd70, %rd67, %rd69; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd71, %rd70, %rd52; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r426}, %rd71; + cvt.u32.u64 %r427, %rd71; + 
shfl.sync.bfly.b32 %r428, %r427, 4, 31, -1; + shfl.sync.bfly.b32 %r429, %r426, 4, 31, -1; + cvt.u64.u32 %rd72, %r428; + cvt.u64.u32 %rd73, %r429; + shl.b64 %rd74, %rd73, 32; + or.b64 %rd75, %rd72, %rd74; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd76, %rd71, %rd75; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r430}, %rd76; + cvt.u32.u64 %r431, %rd76; + shfl.sync.bfly.b32 %r432, %r431, 2, 31, -1; + shfl.sync.bfly.b32 %r433, %r430, 2, 31, -1; + cvt.u64.u32 %rd77, %r432; + cvt.u64.u32 %rd78, %r433; + shl.b64 %rd79, %rd78, 32; + or.b64 %rd80, %rd77, %rd79; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd4, %rd76, %rd80; + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + mov.b64 {_, %r434}, %rd4; + cvt.u32.u64 %r435, %rd4; + shfl.sync.bfly.b32 %r10, %r435, 1, 31, -1; + shfl.sync.bfly.b32 %r11, %r434, 1, 31, -1; +$L__tmp8: + .loc 1 60 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:60:21 + cvt.u32.u64 %r436, %rd36; + .loc 1 64 19 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:64:19 + setp.lt.s32 %p121, %r3, %r436; + .loc 1 66 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:66:35 + selp.b32 %r437, %r4, 16, %p121; + .loc 1 68 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:68:20 + add.s32 %r438, %r437, 17; + .loc 1 69 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:69:20 + setp.lt.s32 %p122, %r437, 0; + .loc 1 70 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:70:35 + selp.b32 %r12, %r438, %r437, %p122; + .loc 1 71 38 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:38 + setp.lt.u32 %p123, %r12, 17; + .loc 1 0 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 + setp.gt.u32 %p124, %r1, 31; + .loc 1 
71 53 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:53 + or.pred %p125, %p124, %p123; + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + @%p125 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: +$L__tmp9: + .loc 1 18 0 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:18 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0 ] + add.s32 %r367, %r365, %r366; + add.s32 %r370, %r368, %r369; + add.s32 %r373, %r371, %r372; + add.s32 %r376, %r374, %r375; + setp.lt.s32 %p114, %r367, %r370; + setp.eq.b32 %p115, %r367, %r370; + setp.gt.s32 %p116, %r373, %r376; + xor.b32 %r377, %r373, %r376; + selp.b32 %r378, %r377, 0, %p116; + selp.b32 %r379, %r378, 0, %p115; + selp.b32 %r380, %r377, %r379, %p114; + xor.b32 %r5, %r380, %r364; +$L__tmp10: + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + cvt.u64.u32 %rd87, %r8; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd88, %rd3, %rd87; +$L__tmp11: + .loc 1 61 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:61:21 + cvt.u32.u64 %r439, %rd88; + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + bar.sync 0; + .loc 1 75 19 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:75:19 + setp.lt.s32 %p127, %r3, %r439; + .loc 1 76 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:76:35 + selp.b32 %r440, %r5, 16, %p127; + .loc 1 77 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:77:20 + add.s32 %r441, %r440, 17; + .loc 1 78 20 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:78:20 + setp.lt.s32 %p128, %r440, 0; + .loc 1 79 35 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:79:35 + selp.b32 %r13, %r441, %r440, %p128; + .loc 1 80 38 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:38 + setp.lt.u32 %p129, %r13, 17; + .loc 1 80 53 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:53 + or.pred %p130, %p124, %p129; + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + @%p130 bra $L__BB0_4; + bra.uni $L__BB0_3; +$L__BB0_4: + .loc 1 0 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:0:63 + ld.param.b64 %rd10, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_6]; + ld.param.b64 %rd9, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_5]; + ld.param.b64 %rd8, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_4]; + ld.param.b64 %rd7, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_3]; + ld.param.b64 %rd6, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_2]; + ld.param.b64 %rd5, [triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2_param_1]; + cvt.s64.s32 %rd1, %r15; +$L__tmp12: + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + cvt.u64.u32 %rd101, %r10; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:59:26 ] + add.s64 %rd102, %rd4, %rd101; +$L__tmp13: + .loc 1 61 21 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:61:21 + cvt.u32.u64 %r443, %rd102; +$L__tmp14: + .loc 3 291 36 // standard.py:291:36 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + cvt.u64.u32 %rd103, %r6; + .loc 3 261 15 // standard.py:261:15 @[ cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:55:26 ] + add.s64 %rd104, %rd2, %rd103; +$L__tmp15: + .loc 1 60 21 // 
cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:60:21 + cvt.u32.u64 %r442, %rd104; + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + bar.sync 0; + .loc 1 81 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:81:25 + mul.wide.u32 %rd105, %r1, 4; + add.s64 %rd95, %rd5, %rd105; + .loc 1 81 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:81:37 + and.b32 %r448, %r2, 63; + setp.eq.b32 %p138, %r448, 0; + and.pred %p131, %p2, %p138; + // begin inline asm + @%p131 st.global.b32 [ %rd95 + 0 ], { %r442 }; + // end inline asm + .loc 1 82 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:82:25 + add.s64 %rd96, %rd6, %rd105; + .loc 1 82 37 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:82:37 + // begin inline asm + @%p131 st.global.b32 [ %rd96 + 0 ], { %r443 }; + // end inline asm + .loc 1 83 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:83:25 + shl.b64 %rd106, %rd1, 2; + add.s64 %rd97, %rd7, %rd106; + .loc 1 83 47 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:83:47 + and.b32 %r449, %r2, 48; + setp.eq.b32 %p139, %r449, 0; + and.pred %p133, %p2, %p139; + // begin inline asm + @%p133 st.global.b32 [ %rd97 + 0 ], { %r4 }; + // end inline asm + .loc 1 84 52 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:52 + mul.lo.s32 %r450, %r1, 17; + .loc 1 84 49 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:49 + add.s32 %r451, %r12, %r450; + .loc 1 84 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:25 + mad.wide.s32 %rd98, %r451, 4, %rd8; + mov.b32 %r445, 1; + .loc 1 84 85 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:84:85 + // begin inline asm + @%p133 st.global.b32 [ %rd98 + 0 ], { %r445 }; + // end inline asm + .loc 1 85 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:85:25 + add.s64 %rd99, %rd9, %rd106; + .loc 1 85 47 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:85:47 + // begin inline 
asm + @%p133 st.global.b32 [ %rd99 + 0 ], { %r5 }; + // end inline asm + .loc 1 86 49 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:49 + add.s32 %r452, %r13, %r450; + .loc 1 86 25 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:25 + mad.wide.s32 %rd100, %r452, 4, %rd10; + .loc 1 86 85 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:85 + // begin inline asm + @%p133 st.global.b32 [ %rd100 + 0 ], { %r445 }; + // end inline asm + .loc 1 86 4 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:86:4 + ret; +$L__BB0_1: + .loc 1 71 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:71:63 + { // callseq 0, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd81, assertFunc_0; + cvta.global.u64 %rd82, %rd81; + st.param.b64 [param3], %rd82; + mov.b64 %rd83, assertFile_0; + cvta.global.u64 %rd84, %rd83; + st.param.b64 [param1], %rd84; + mov.b64 %rd85, assertMessage_0; + cvta.global.u64 %rd86, %rd85; + st.param.b64 [param0], %rd86; + st.param.b64 [param4], 1; + st.param.b32 [param2], 71; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 0 + trap; +$L__BB0_3: + .loc 1 80 63 // cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py:80:63 + { // callseq 1, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd89, assertFunc_1; + cvta.global.u64 %rd90, %rd89; + st.param.b64 [param3], %rd90; + mov.b64 %rd91, assertFile_1; + cvta.global.u64 %rd92, %rd91; + st.param.b64 [param1], %rd92; + mov.b64 %rd93, assertMessage_1; + cvta.global.u64 %rd94, %rd93; + st.param.b64 [param0], %rd94; + st.param.b64 [param4], 1; + st.param.b32 [param2], 80; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 1 + trap; +$L__tmp16: +$L__func_end0: + // -- End function +} + .file 1 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .file 3 "/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 5 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // 
EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 399 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x188 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 98 +.b8 115 +.b8 54 +.b8 53 +.b8 50 +.b8 105 +.b8 55 +.b8 99 +.b8 116 +.b8 53 +.b8 55 +.b8 117 +.b8 103 +.b8 120 +.b8 54 +.b8 118 +.b8 110 +.b8 99 +.b8 51 +.b8 53 +.b8 110 +.b8 54 +.b8 51 +.b8 116 +.b8 105 +.b8 110 +.b8 112 +.b8 122 +.b8 117 +.b8 54 +.b8 98 +.b8 97 +.b8 51 +.b8 122 +.b8 121 +.b8 109 +.b8 117 +.b8 102 +.b8 111 +.b8 100 +.b8 105 +.b8 103 +.b8 52 +.b8 104 +.b8 112 +.b8 122 +.b8 97 +.b8 114 +.b8 119 +.b8 99 +.b8 108 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 98 +.b8 115 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x7a DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 112 +.b8 101 +.b8 114 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 95 +.b8 116 +.b8 111 +.b8 95 +.b8 99 +.b8 111 +.b8 112 +.b8 121 +.b8 95 +.b8 97 +.b8 114 +.b8 97 +.b8 110 +.b8 103 +.b8 101 +.b8 95 +.b8 98 +.b8 105 +.b8 116 +.b8 119 +.b8 105 +.b8 115 +.b8 101 +.b8 95 +.b8 97 +.b8 110 +.b8 100 +.b8 95 +.b8 101 +.b8 113 +.b8 95 +.b8 103 +.b8 116 +.b8 95 +.b8 105 +.b8 110 +.b8 100 +.b8 101 +.b8 120 +.b8 95 +.b8 112 +.b8 117 +.b8 
116 +.b8 95 +.b8 108 +.b8 116 +.b8 95 +.b8 110 +.b8 101 +.b8 119 +.b8 95 +.b8 122 +.b8 101 +.b8 114 +.b8 111 +.b8 115 +.b8 95 +.b8 115 +.b8 99 +.b8 97 +.b8 108 +.b8 97 +.b8 114 +.b8 95 +.b8 116 +.b8 101 +.b8 110 +.b8 115 +.b8 111 +.b8 114 +.b8 95 +.b8 115 +.b8 111 +.b8 114 +.b8 116 +.b8 95 +.b8 115 +.b8 117 +.b8 109 +.b8 95 +.b8 117 +.b8 110 +.b8 115 +.b8 113 +.b8 117 +.b8 101 +.b8 101 +.b8 122 +.b8 101 +.b8 95 +.b8 118 +.b8 105 +.b8 101 +.b8 119 +.b8 95 +.b8 119 +.b8 104 +.b8 101 +.b8 114 +.b8 101 +.b8 95 +.b8 50 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0x105:0x8d DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0x11a:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 46 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x132:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp3 // DW_AT_low_pc +.b64 $L__tmp4 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 51 // DW_AT_call_line +.b8 71 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x14a:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp5 // DW_AT_low_pc +.b64 $L__tmp15 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 55 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x162:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp7 // DW_AT_low_pc +.b64 $L__tmp13 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 59 // DW_AT_call_line +.b8 26 // DW_AT_call_column +.b8 5 // Abbrev [5] 0x17a:0x17 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp9 // DW_AT_low_pc +.b64 $L__tmp10 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 0 // DW_AT_call_line +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source new file mode 100644 index 0000000000000000000000000000000000000000..d32c1c734b5487a50e4ffa31c4dfb628e9ad1261 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.source @@ -0,0 +1,1397 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc90 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":640:0) +#loc94 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":607:0) +#loc102 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":518:0) +#loc140 = loc(unknown) +#loc165 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc169 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc174 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":86:0) +#loc178 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":63:0) +#loc187 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":131:0) +#loc192 = loc("in_ptr0"(#loc)) +#loc193 = loc("out_ptr4"(#loc)) +#loc194 = loc("out_ptr5"(#loc)) +#loc195 = 
loc("out_ptr6"(#loc)) +#loc196 = loc("out_ptr7"(#loc)) +#loc197 = loc("out_ptr8"(#loc)) +#loc198 = loc("out_ptr9"(#loc)) +#loc199 = loc("xnumel"(#loc)) +#loc200 = loc("r0_numel"(#loc)) +#loc255 = loc("x"(#loc90)) +#loc256 = loc("idxs"(#loc90)) +#loc257 = loc("x"(#loc94)) +#loc258 = loc("idxs"(#loc94)) +#loc263 = loc("x"(#loc102)) +#loc264 = loc("idxs"(#loc102)) +#loc265 = loc("flip"(#loc102)) +#loc321 = loc("input"(#loc165)) +#loc322 = loc("a"(#loc169)) +#loc323 = loc("b"(#loc169)) +#loc325 = loc("x"(#loc174)) +#loc326 = loc("x"(#loc178)) +#loc327 = loc("input"(#loc187)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 32 : i32 loc(#loc201) + %r0_numel_1 = arith.constant 16 : i32 loc(#loc202) + %xoffset = tt.get_program_id x : i32 loc(#loc203) + %xoffset_2 = arith.constant 1 : i32 loc(#loc204) + %xoffset_3 = arith.constant 1 : i32 loc(#loc204) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc204) + %xindex = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> loc(#loc205) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc206) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<1x1xi32> loc(#loc207) + %xindex_7 = arith.addi 
%xindex_6, %xindex_5 : tensor<1x1xi32> loc(#loc207) + %xmask = arith.constant dense<32> : tensor<1x1xi32> loc(#loc208) + %xmask_8 = arith.cmpi slt, %xindex_7, %xmask : tensor<1x1xi32> loc(#loc208) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc209) + %r0_index_9 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc210) + %r0_offset = arith.constant 0 : i32 loc(#loc211) + %r0_mask = arith.constant true loc(#loc212) + %r0_mask_10 = arith.constant dense : tensor<1x16xi1> loc(#loc212) + %tmp0 = arith.constant 16 : i32 loc(#loc213) + %tmp0_11 = arith.constant 16 : i32 loc(#loc213) + %tmp0_12 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc213) + %tmp0_13 = arith.muli %tmp0_12, %xindex_7 : tensor<1x1xi32> loc(#loc213) + %tmp0_14 = tt.broadcast %tmp0_13 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc214) + %tmp0_15 = arith.addi %r0_index_9, %tmp0_14 : tensor<1x16xi32> loc(#loc214) + %tmp0_16 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc215) + %tmp0_17 = tt.addptr %tmp0_16, %tmp0_15 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc215) + %tmp0_18 = arith.constant 0.000000e+00 : f32 loc(#loc216) + %tmp0_19 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc216) + %tmp0_20 = arith.constant dense<0.000000e+00> : tensor<1x16xf32> loc(#loc216) + %tmp0_21 = arith.fptosi %tmp0_20 : tensor<1x16xf32> to tensor<1x16xi64> loc(#loc216) + %tmp0_22 = tt.load %tmp0_17, %tmp0_19, %tmp0_21 : tensor<1x16x!tt.ptr> loc(#loc216) + %tmp1 = arith.constant 0 : i64 loc(#loc217) + %tmp1_23 = arith.constant dense<0> : tensor<1x1xi64> loc(#loc217) + %tmp2 = arith.constant dense<0> : tensor<1x16xi64> loc(#loc218) + %tmp2_24 = arith.cmpi sgt, %tmp0_22, %tmp2 : tensor<1x16xi64> loc(#loc218) + %tmp3 = arith.constant 16384 : i64 loc(#loc219) + %tmp3_25 = arith.constant dense<16384> : tensor<1x1xi64> loc(#loc219) + %tmp4 = arith.constant dense<16384> : tensor<1x16xi64> loc(#loc220) + %tmp4_26 = 
arith.cmpi slt, %tmp0_22, %tmp4 : tensor<1x16xi64> loc(#loc220) + %tmp5 = arith.andi %tmp2_24, %tmp4_26 : tensor<1x16xi1> loc(#loc221) + %tmp6 = arith.extui %tmp5 : tensor<1x16xi1> to tensor<1x16xi8> loc(#loc222) + %tmp7 = arith.extsi %tmp6 : tensor<1x16xi8> to tensor<1x16xi32> loc(#loc223) + %tmp9 = arith.trunci %r0_index_9 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc224) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S1_16S_i16S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp7, %tmp9) : (tensor<1x16xi32>, tensor<1x16xi16>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc25) + %tmp14 = arith.constant dense<16384> : tensor<1x16xi64> loc(#loc225) + %tmp14_27 = arith.cmpi eq, %tmp0_22, %tmp14 : tensor<1x16xi64> loc(#loc225) + %tmp15 = arith.extui %tmp14_27 : tensor<1x16xi1> to tensor<1x16xi8> loc(#loc226) + %tmp16 = arith.extsi %tmp15 : tensor<1x16xi8> to tensor<1x16xi32> loc(#loc227) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers.sort_with_index__i32S1_16S_i16S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%tmp16, %tmp9) : (tensor<1x16xi32>, tensor<1x16xi16>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc29) + %tmp20 = arith.extsi %tmp7 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc228) + %tmp23 = arith.constant 0 : i32 loc(#loc229) + %tmp23_28 = arith.constant 0 : i64 loc(#loc229) + %tmp23_29 = arith.constant dense<0> : tensor<1x16xi64> loc(#loc229) + %tmp23_30 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc229) + %tmp23_31 = arith.select %tmp23_30, %tmp20, %tmp23_29 : tensor<1x16xi1>, tensor<1x16xi64> loc(#loc229) + %tmp24 = tt.call @"triton.language.standard.sum__i64S1_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp23_31) : (tensor<1x16xi64>) -> tensor<1xi64> loc(#loc230) + %tmp24_32 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc231) + %tmp25 = arith.extsi 
%tmp16 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc232) + %tmp28 = arith.constant 0 : i32 loc(#loc233) + %tmp28_33 = arith.constant 0 : i64 loc(#loc233) + %tmp28_34 = arith.constant dense<0> : tensor<1x16xi64> loc(#loc233) + %tmp28_35 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc233) + %tmp28_36 = arith.select %tmp28_35, %tmp25, %tmp28_34 : tensor<1x16xi1>, tensor<1x16xi64> loc(#loc233) + %tmp29 = tt.call @"triton.language.standard.sum__i64S1_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%tmp28_36) : (tensor<1x16xi64>) -> tensor<1xi64> loc(#loc234) + %tmp29_37 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc235) + %tmp30 = arith.trunci %tmp24_32 : tensor<1x1xi64> to tensor<1x1xi32> loc(#loc236) + %tmp31 = arith.trunci %tmp29_37 : tensor<1x1xi64> to tensor<1x1xi32> loc(#loc237) + %tmp32 = arith.extsi %0#1 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc238) + %tmp33 = arith.trunci %tmp32 : tensor<1x16xi64> to tensor<1x16xi32> loc(#loc239) + %tmp34 = tt.broadcast %tmp30 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc240) + %tmp34_38 = arith.cmpi slt, %r0_index_9, %tmp34 : tensor<1x16xi32> loc(#loc240) + %tmp35 = arith.constant 16 : i32 loc(#loc241) + %tmp35_39 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc241) + %tmp36 = arith.constant dense<16> : tensor<1x16xi32> loc(#loc242) + %tmp36_40 = arith.select %tmp34_38, %tmp33, %tmp36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc242) + %tmp37 = arith.constant 17 : i32 loc(#loc243) + %tmp37_41 = arith.constant dense<17> : tensor<1x16xi32> loc(#loc243) + %tmp38 = arith.addi %tmp36_40, %tmp37_41 : tensor<1x16xi32> loc(#loc244) + %tmp39 = arith.constant 0 : i32 loc(#loc245) + %tmp39_42 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc245) + %tmp39_43 = arith.cmpi slt, %tmp36_40, %tmp39_42 : tensor<1x16xi32> loc(#loc245) + %tmp40 = arith.select %tmp39_43, %tmp38, %tmp36_40 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc246) + %c0_i32 = arith.constant 0 
: i32 loc(#loc49) + %cst = arith.constant dense<0> : tensor<1x16xi32> loc(#loc49) + %2 = arith.cmpi sle, %cst, %tmp40 : tensor<1x16xi32> loc(#loc49) + %c17_i32 = arith.constant 17 : i32 loc(#loc50) + %cst_44 = arith.constant dense<17> : tensor<1x16xi32> loc(#loc50) + %3 = arith.cmpi slt, %tmp40, %cst_44 : tensor<1x16xi32> loc(#loc50) + %4 = arith.andi %2, %3 : tensor<1x16xi1> loc(#loc51) + %true = arith.constant true loc(#loc52) + %cst_45 = arith.constant dense : tensor<1x1xi1> loc(#loc52) + %5 = arith.xori %xmask_8, %cst_45 : tensor<1x1xi1> loc(#loc52) + %6 = tt.broadcast %5 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc53) + %7 = arith.ori %4, %6 : tensor<1x16xi1> loc(#loc53) + tt.assert %7, "index out of bounds: 0 <= tmp40 < 17" : tensor<1x16xi1> loc(#loc54) + %tmp42 = arith.constant 1 : i32 loc(#loc247) + %tmp42_46 = arith.constant dense<1> : tensor<1x1xi32> loc(#loc247) + %tmp43 = arith.extsi %1#1 : tensor<1x16xi32> to tensor<1x16xi64> loc(#loc248) + %tmp44 = arith.trunci %tmp43 : tensor<1x16xi64> to tensor<1x16xi32> loc(#loc249) + %tmp45 = tt.broadcast %tmp31 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc250) + %tmp45_47 = arith.cmpi slt, %r0_index_9, %tmp45 : tensor<1x16xi32> loc(#loc250) + %tmp46 = arith.constant dense<16> : tensor<1x16xi32> loc(#loc251) + %tmp46_48 = arith.select %tmp45_47, %tmp44, %tmp46 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc251) + %tmp47 = arith.addi %tmp46_48, %tmp37_41 : tensor<1x16xi32> loc(#loc252) + %tmp48 = arith.constant 0 : i32 loc(#loc253) + %tmp48_49 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc253) + %tmp48_50 = arith.cmpi slt, %tmp46_48, %tmp48_49 : tensor<1x16xi32> loc(#loc253) + %tmp49 = arith.select %tmp48_50, %tmp47, %tmp46_48 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc254) + %c0_i32_51 = arith.constant 0 : i32 loc(#loc63) + %cst_52 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc63) + %8 = arith.cmpi sle, %cst_52, %tmp49 : tensor<1x16xi32> loc(#loc63) + %c17_i32_53 = arith.constant 17 : i32 loc(#loc64) 
+ %cst_54 = arith.constant dense<17> : tensor<1x16xi32> loc(#loc64) + %9 = arith.cmpi slt, %tmp49, %cst_54 : tensor<1x16xi32> loc(#loc64) + %10 = arith.andi %8, %9 : tensor<1x16xi1> loc(#loc65) + %true_55 = arith.constant true loc(#loc66) + %cst_56 = arith.constant dense : tensor<1x1xi1> loc(#loc66) + %11 = arith.xori %xmask_8, %cst_56 : tensor<1x1xi1> loc(#loc66) + %12 = tt.broadcast %11 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc67) + %13 = arith.ori %10, %12 : tensor<1x16xi1> loc(#loc67) + tt.assert %13, "index out of bounds: 0 <= tmp49 < 17" : tensor<1x16xi1> loc(#loc68) + %14 = tt.splat %out_ptr4 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc69) + %15 = tt.addptr %14, %xindex_7 : tensor<1x1x!tt.ptr>, tensor<1x1xi32> loc(#loc69) + tt.store %15, %tmp30, %xmask_8 : tensor<1x1x!tt.ptr> loc(#loc70) + %16 = tt.splat %out_ptr5 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc71) + %17 = tt.addptr %16, %xindex_7 : tensor<1x1x!tt.ptr>, tensor<1x1xi32> loc(#loc71) + tt.store %17, %tmp31, %xmask_8 : tensor<1x1x!tt.ptr> loc(#loc72) + %c16_i32 = arith.constant 16 : i32 loc(#loc73) + %c16_i32_57 = arith.constant 16 : i32 loc(#loc73) + %cst_58 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc73) + %18 = arith.muli %cst_58, %xindex_7 : tensor<1x1xi32> loc(#loc73) + %19 = tt.broadcast %18 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc74) + %20 = arith.addi %r0_index_9, %19 : tensor<1x16xi32> loc(#loc74) + %21 = tt.splat %out_ptr6 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc75) + %22 = tt.addptr %21, %20 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc75) + %23 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc76) + tt.store %22, %tmp33, %23 : tensor<1x16x!tt.ptr> loc(#loc76) + %c17_i32_59 = arith.constant 17 : i32 loc(#loc77) + %c17_i32_60 = arith.constant 17 : i32 loc(#loc77) + %cst_61 = arith.constant dense<17> : tensor<1x1xi32> loc(#loc77) + %24 = arith.muli %cst_61, %xindex_7 : tensor<1x1xi32> loc(#loc77) + %25 = tt.broadcast %24 : tensor<1x1xi32> -> 
tensor<1x16xi32> loc(#loc78) + %26 = arith.addi %tmp40, %25 : tensor<1x16xi32> loc(#loc78) + %27 = tt.splat %out_ptr7 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc79) + %28 = tt.addptr %27, %26 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc79) + %cst_62 = arith.constant dense<1> : tensor<1x16xi32> loc(#loc80) + %29 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc80) + tt.store %28, %cst_62, %29 : tensor<1x16x!tt.ptr> loc(#loc80) + %c16_i32_63 = arith.constant 16 : i32 loc(#loc81) + %c16_i32_64 = arith.constant 16 : i32 loc(#loc81) + %cst_65 = arith.constant dense<16> : tensor<1x1xi32> loc(#loc81) + %30 = arith.muli %cst_65, %xindex_7 : tensor<1x1xi32> loc(#loc81) + %31 = tt.broadcast %30 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc82) + %32 = arith.addi %r0_index_9, %31 : tensor<1x16xi32> loc(#loc82) + %33 = tt.splat %out_ptr8 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc83) + %34 = tt.addptr %33, %32 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc83) + %35 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc84) + tt.store %34, %tmp44, %35 : tensor<1x16x!tt.ptr> loc(#loc84) + %c17_i32_66 = arith.constant 17 : i32 loc(#loc85) + %c17_i32_67 = arith.constant 17 : i32 loc(#loc85) + %cst_68 = arith.constant dense<17> : tensor<1x1xi32> loc(#loc85) + %36 = arith.muli %cst_68, %xindex_7 : tensor<1x1xi32> loc(#loc85) + %37 = tt.broadcast %36 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc86) + %38 = arith.addi %tmp49, %37 : tensor<1x16xi32> loc(#loc86) + %39 = tt.splat %out_ptr9 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc87) + %40 = tt.addptr %39, %38 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc87) + %cst_69 = arith.constant dense<1> : tensor<1x16xi32> loc(#loc88) + %41 = tt.broadcast %xmask_8 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc88) + tt.store %40, %cst_69, %41 : tensor<1x16x!tt.ptr> loc(#loc88) + tt.return loc(#loc89) + } loc(#loc) + tt.func private 
@"torch._inductor.runtime.triton_helpers.sort_with_index__i32S1_16S_i16S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc90)), %idxs: tensor<1x16xi16> loc("idxs"(#loc90))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i16S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs) : (tensor<1x16xi32>, tensor<1x16xi16>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc91) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1) : (tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc91) + %2:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1) : (tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc91) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1) : (tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc91) + tt.return %3#0, %3#1 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc92) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc93) + %5 = ub.poison : tensor<1x16xi32> loc(#loc93) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc93) + } loc(#loc90) + tt.func private 
@"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i16S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_1__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc94)), %idxs: tensor<1x16xi16> loc("idxs"(#loc94))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc259) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc260) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc260) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc261) + %flip_3 = tt.reshape %flip_2 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc262) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i16S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<1x16xi32>, tensor<1x16xi16>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + tt.return %0#0, %0#1 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc100) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1x16xi32> loc(#loc101) + %2 = ub.poison : tensor<1x16xi32> loc(#loc101) + tt.return %1, %2 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc101) + } loc(#loc94) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i16S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi16> loc("idxs"(#loc102)), %flip: tensor<1x16xi32> loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, 
start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<8x2x1xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<8x2x1xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi16> -> tensor<8x2x1xi16> loc(#loc280) + %left_idx = arith.trunci %left_mask_4 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc281) + %left_idx_15 = 
tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<8x2x1xi16> loc(#loc282) + %left_idx_16 = arith.muli %y_idx, %left_idx_15 : tensor<8x2x1xi16> loc(#loc282) + %left_idx_17 = tt.call @"triton.language.standard.sum__i16S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_16) : (tensor<8x2x1xi16>) -> tensor<8x1xi32> loc(#loc283) + %left_idx_18 = tt.expand_dims %left_idx_17 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc284) + %left_idx_19 = tt.broadcast %left_idx_18 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc285) + %right_idx = arith.trunci %right_mask_1 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc286) + %right_idx_20 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<8x2x1xi16> loc(#loc287) + %right_idx_21 = arith.muli %y_idx, %right_idx_20 : tensor<8x2x1xi16> loc(#loc287) + %right_idx_22 = tt.call @"triton.language.standard.sum__i16S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_21) : (tensor<8x2x1xi16>) -> tensor<8x1xi32> loc(#loc288) + %right_idx_23 = tt.expand_dims %right_idx_22 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %right_idx_24 = tt.broadcast %right_idx_23 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %left_idx_25 = tt.reshape %left_idx_19 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_26 = tt.reshape %right_idx_24 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_27 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_28 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call 
@torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_49 = arith.constant true loc(#loc298) + %cond_50 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_51 = arith.xori %left_isnan, %cond_50 : tensor<1x16xi1> loc(#loc298) + %cond_52 = arith.andi %right_isnan, %cond_51 : tensor<1x16xi1> loc(#loc299) + %cond_53 = arith.ori %cond, %cond_52 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_53 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_49 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_50 = arith.ori %eq, %eq_49 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_50 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_29 = arith.cmpi sgt, %left_idx_25, %right_idx_26 : tensor<1x16xi32> loc(#loc304) + %cond_30 = arith.andi %3, %cond_29 : tensor<1x16xi1> loc(#loc305) + %cond_31 = arith.ori %1, %cond_30 : tensor<1x16xi1> loc(#loc306) + %cond_32 = arith.cmpi ugt, %right_valid_mask_28, %left_valid_mask_27 : tensor<1x16xi1> loc(#loc307) + %cond_33 = arith.cmpi eq, %right_valid_mask_28, %left_valid_mask_27 : tensor<1x16xi1> loc(#loc308) + %cond_34 = arith.andi %cond_33, %cond_31 : tensor<1x16xi1> loc(#loc309) + %cond_35 = arith.ori %cond_32, %cond_34 : tensor<1x16xi1> loc(#loc310) + %cond_36 = arith.extui %cond_35 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc311) + %cond_37 = arith.xori %cond_36, %flip : tensor<1x16xi32> loc(#loc311) + %cond_38 = arith.constant 0 : i32 loc(#loc312) + %cond_39 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc312) + %cond_40 = 
arith.cmpi ne, %cond_37, %cond_39 : tensor<1x16xi32> loc(#loc312) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_41 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_42 = arith.select %cond_40, %ret, %ret_41 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_43 = arith.xori %x, %ret_42 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_25, %right_idx_26 : tensor<1x16xi32> loc(#loc317) + %new_idxs_44 = tt.call @triton.language.standard.zeros_like__i16S1_16S__(%idxs) : (tensor<1x16xi16>) -> tensor<1x16xi16> loc(#loc318) + %new_idxs_45 = arith.extsi %new_idxs_44 : tensor<1x16xi16> to tensor<1x16xi32> loc(#loc319) + %new_idxs_46 = arith.select %cond_40, %new_idxs, %new_idxs_45 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_47 = arith.extsi %idxs : tensor<1x16xi16> to tensor<1x16xi32> loc(#loc320) + %new_idxs_48 = arith.xori %new_idxs_47, %new_idxs_46 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_43, %new_idxs_48 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<8x2x1xi32> loc("input"(#loc165))) -> tensor<8x1xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc166) + tt.reduce.return %2 : i32 loc(#loc166) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc166) + tt.return %0 : tensor<8x1xi32> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x1xi32> loc(#loc168) + tt.return %1 : 
tensor<8x1xi32> loc(#loc168) + } loc(#loc165) + tt.func private @triton.language.standard._sum_combine__i32_i32__(%a: i32 loc("a"(#loc169)), %b: i32 loc("b"(#loc169))) -> i32 attributes {noinline = false} { + %0 = arith.addi %a, %b : i32 loc(#loc170) + tt.return %0 : i32 loc(#loc171) + ^bb1: // no predecessors + %1 = ub.poison : i32 loc(#loc172) + tt.return %1 : i32 loc(#loc172) + } loc(#loc169) + tt.func private @"triton.language.standard.sum__i16S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<8x2x1xi16> loc("input"(#loc165))) -> tensor<8x1xi32> attributes {noinline = false} { + %input_0 = arith.extsi %input : tensor<8x2x1xi16> to tensor<8x2x1xi32> loc(#loc324) + %0 = "tt.reduce"(%input_0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc166) + tt.reduce.return %2 : i32 loc(#loc166) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc166) + tt.return %0 : tensor<8x1xi32> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<8x1xi32> loc(#loc168) + tt.return %1 : tensor<8x1xi32> loc(#loc168) + } loc(#loc165) + tt.func private @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%x: tensor<1x16xi32> loc("x"(#loc174))) -> i1 attributes {noinline = false} { + %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc175) + %false = arith.constant false loc(#loc176) + tt.return %false : i1 loc(#loc176) + ^bb1: // no predecessors + %1 = ub.poison : i1 loc(#loc177) + tt.return %1 : i1 loc(#loc177) + } loc(#loc174) + tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__i32S1_16S__(%x: tensor<1x16xi32> loc("x"(#loc178))) -> tensor<1x16xi32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1> 
loc(#loc179) + %1 = arith.extui %0 : tensor<1xi1> to tensor<1xi32> loc(#loc180) + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc180) + %3 = tt.broadcast %2 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc180) + %4 = arith.addi %x, %3 : tensor<1x16xi32> loc(#loc180) + tt.return %4 : tensor<1x16xi32> loc(#loc181) + ^bb1: // no predecessors + %5 = ub.poison : tensor<1x16xi32> loc(#loc182) + tt.return %5 : tensor<1x16xi32> loc(#loc182) + } loc(#loc178) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} { + %false = arith.constant false loc(#loc184) + %cst = arith.constant dense : tensor<1xi1> loc(#loc184) + tt.return %cst : tensor<1xi1> loc(#loc185) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1xi1> loc(#loc186) + tt.return %0 : tensor<1xi1> loc(#loc186) + } loc(#loc183) + tt.func private @triton.language.standard.zeros_like__i32S1_16S__(%input: tensor<1x16xi32> loc("input"(#loc187))) -> tensor<1x16xi32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() : () -> tensor<1x16xi32> loc(#loc188) + tt.return %0 : tensor<1x16xi32> loc(#loc189) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1x16xi32> loc(#loc190) + tt.return %1 : tensor<1x16xi32> loc(#loc190) + } loc(#loc187) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(0, 1)cconstexpr_16__(1,)cconstexpr_int32_"() -> tensor<1x16xi32> attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 loc(#loc184) + %cst = arith.constant dense<0> : tensor<1x16xi32> loc(#loc184) + tt.return %cst : tensor<1x16xi32> loc(#loc185) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1x16xi32> loc(#loc186) + tt.return %0 : tensor<1x16xi32> loc(#loc186) + } loc(#loc183) + tt.func private @triton.language.standard.zeros_like__i16S1_16S__(%input: tensor<1x16xi16> 
loc("input"(#loc187))) -> tensor<1x16xi16> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() : () -> tensor<1x16xi16> loc(#loc188) + tt.return %0 : tensor<1x16xi16> loc(#loc189) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1x16xi16> loc(#loc190) + tt.return %1 : tensor<1x16xi16> loc(#loc190) + } loc(#loc187) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(0, 1)cconstexpr_16__(1,)cconstexpr_int16_"() -> tensor<1x16xi16> attributes {noinline = false} { + %c0_i16 = arith.constant 0 : i16 loc(#loc184) + %cst = arith.constant dense<0> : tensor<1x16xi16> loc(#loc184) + tt.return %cst : tensor<1x16xi16> loc(#loc185) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1x16xi16> loc(#loc186) + tt.return %0 : tensor<1x16xi16> loc(#loc186) + } loc(#loc183) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_2__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc94)), %idxs: tensor<1x16xi32> loc("idxs"(#loc94))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc259) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc260) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc260) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc261) + %flip_3 = tt.reshape %flip_2 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc262) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<1x16xi32>, 
tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : (tensor<1x16xi32>, tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + tt.return %1#0, %1#1 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc100) + ^bb1: // no predecessors + %2 = ub.poison : tensor<1x16xi32> loc(#loc101) + %3 = ub.poison : tensor<1x16xi32> loc(#loc101) + tt.return %2, %3 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc101) + } loc(#loc94) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: tensor<1x16xi32> loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<4x2x2xi32> loc(#loc270) + %ileft_6 = tt.call 
@"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<4x2x2xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<4x2x2xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<4x2x2xi32> loc(#loc287) + %right_idx_20 = tt.call 
@"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_45 = arith.constant true loc(#loc298) + %cond_46 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<1x16xi1> loc(#loc298) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<1x16xi1> loc(#loc299) + %cond_49 = arith.ori %cond, %cond_48 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_49 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_45 = arith.andi 
%left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_46 = arith.ori %eq, %eq_45 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_46 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = arith.extui %cond_33 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc311) + %cond_35 = arith.xori %cond_34, %flip : tensor<1x16xi32> loc(#loc311) + %cond_36 = arith.constant 0 : i32 loc(#loc312) + %cond_37 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc312) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<1x16xi32> loc(#loc312) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_41 = arith.xori %x, %ret_40 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_41, %new_idxs_44 : 
tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<4x2x2xi32> loc("input"(#loc165))) -> tensor<4x2xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc166) + tt.reduce.return %2 : i32 loc(#loc166) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc166) + tt.return %0 : tensor<4x2xi32> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<4x2xi32> loc(#loc168) + tt.return %1 : tensor<4x2xi32> loc(#loc168) + } loc(#loc165) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: tensor<1x16xi32> loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi 
%left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<8x2x1xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<8x2x1xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<8x2x1xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc285) + %right_idx = 
tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<8x2x1xi32> loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_45 = arith.constant true loc(#loc298) + %cond_46 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<1x16xi1> loc(#loc298) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<1x16xi1> loc(#loc299) + %cond_49 = arith.ori %cond, %cond_48 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_49 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call 
@torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_46 = arith.ori %eq, %eq_45 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_46 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = arith.extui %cond_33 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc311) + %cond_35 = arith.xori %cond_34, %flip : tensor<1x16xi32> loc(#loc311) + %cond_36 = arith.constant 0 : i32 loc(#loc312) + %cond_37 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc312) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<1x16xi32> loc(#loc312) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_41 = arith.xori %x, %ret_40 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_43 = arith.select %cond_38, %new_idxs, 
%new_idxs_42 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_41, %new_idxs_44 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_3__(4,)cconstexpr_True__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc94)), %idxs: tensor<1x16xi32> loc("idxs"(#loc94))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc259) + %flip_0 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc260) + %flip_1 = tt.expand_dims %flip_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc260) + %flip_2 = tt.broadcast %flip_1 : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc261) + %flip_3 = tt.reshape %flip_2 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc262) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip_3) : (tensor<1x16xi32>, tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip_3) : (tensor<1x16xi32>, tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %2:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip_3) : (tensor<1x16xi32>, tensor<1x16xi32>, tensor<1x16xi32>) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + tt.return %2#0, %2#1 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc100) + ^bb1: // no predecessors + %3 = ub.poison : tensor<1x16xi32> loc(#loc101) + %4 = ub.poison : tensor<1x16xi32> loc(#loc101) + tt.return %3, %4 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc101) + } loc(#loc94) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: tensor<1x16xi32> loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<2x2x4xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : 
(tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<2x2x4xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<2x2x4xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<2x2x4xi32> loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc288) + 
%right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_45 = arith.constant true loc(#loc298) + %cond_46 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_47 = arith.xori %left_isnan, %cond_46 : tensor<1x16xi1> loc(#loc298) + %cond_48 = arith.andi %right_isnan, %cond_47 : tensor<1x16xi1> loc(#loc299) + %cond_49 = arith.ori %cond, %cond_48 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_49 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_45 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_46 = arith.ori %eq, %eq_45 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_46 : tensor<1x16xi1> loc(#loc332) + } 
else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = arith.extui %cond_33 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc311) + %cond_35 = arith.xori %cond_34, %flip : tensor<1x16xi32> loc(#loc311) + %cond_36 = arith.constant 0 : i32 loc(#loc312) + %cond_37 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc312) + %cond_38 = arith.cmpi ne, %cond_35, %cond_37 : tensor<1x16xi32> loc(#loc312) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_40 = arith.select %cond_38, %ret, %ret_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_41 = arith.xori %x, %ret_40 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_42 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_43 = arith.select %cond_38, %new_idxs, %new_idxs_42 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_44 = arith.xori %idxs, %new_idxs_43 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_41, %new_idxs_44 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + 
tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<2x2x4xi32> loc("input"(#loc165))) -> tensor<2x4xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc166) + tt.reduce.return %2 : i32 loc(#loc166) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc166) + tt.return %0 : tensor<2x4xi32> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<2x4xi32> loc(#loc168) + tt.return %1 : tensor<2x4xi32> loc(#loc168) + } loc(#loc165) + tt.func private @"torch._inductor.runtime.triton_helpers._bitonic_merge_with_index__i32S1_16S_i32S1_16S__(2,)cconstexpr_None__(3,)cconstexpr_4__(4,)cconstexpr_False__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc94)), %idxs: tensor<1x16xi32> loc("idxs"(#loc94))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %flip = arith.constant false loc(#loc328) + %0:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x, %idxs, %flip) : (tensor<1x16xi32>, tensor<1x16xi32>, i1) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %1:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%0#0, %0#1, %flip) : (tensor<1x16xi32>, tensor<1x16xi32>, i1) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %2:2 = tt.call 
@"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%1#0, %1#1, %flip) : (tensor<1x16xi32>, tensor<1x16xi32>, i1) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + %3:2 = tt.call @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%2#0, %2#1, %flip) : (tensor<1x16xi32>, tensor<1x16xi32>, i1) -> (tensor<1x16xi32>, tensor<1x16xi32>) loc(#loc99) + tt.return %3#0, %3#1 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc100) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc101) + %5 = ub.poison : tensor<1x16xi32> loc(#loc101) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc101) + } loc(#loc94) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_0__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: i1 loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = 
tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<1x2x8xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S1_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<1x2x8xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S1_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<1x2x8xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S1_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc287) + 
%right_idx_19 = arith.muli %y_idx, %right_idx : tensor<1x2x8xi32> loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S1_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_42 = arith.constant true loc(#loc298) + %cond_43 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<1x16xi1> loc(#loc298) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<1x16xi1> loc(#loc299) + %cond_46 = arith.ori %cond, %cond_45 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_46 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : 
(tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_43 = arith.ori %eq, %eq_42 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_43 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = tt.splat %flip : i1 -> tensor<1x16xi1> loc(#loc311) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<1x16xi1> loc(#loc311) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_38 = arith.xori %x, %ret_37 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_38, %new_idxs_41 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) 
+ %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"triton.language.standard.sum__i32S1_2_8S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<1x2x8xi32> loc("input"(#loc165))) -> tensor<1x8xi32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i32 loc(unknown), %arg2: i32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i32_i32__(%arg1, %arg2) : (i32, i32) -> i32 loc(#loc166) + tt.reduce.return %2 : i32 loc(#loc166) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc166) + tt.return %0 : tensor<1x8xi32> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1x8xi32> loc(#loc168) + tt.return %1 : tensor<1x8xi32> loc(#loc168) + } loc(#loc165) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_1__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: i1 loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> 
loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<2x2x4xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<2x2x4xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<2x2x4xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<2x2x4xi32> 
loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S2_2_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_42 = arith.constant true loc(#loc298) + %cond_43 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<1x16xi1> loc(#loc298) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<1x16xi1> loc(#loc299) + %cond_46 = arith.ori %cond, %cond_45 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_46 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> 
(tensor<1x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_43 = arith.ori %eq, %eq_42 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_43 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = tt.splat %flip : i1 -> tensor<1x16xi1> loc(#loc311) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<1x16xi1> loc(#loc311) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_38 = arith.xori %x, %ret_37 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_38, %new_idxs_41 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + 
tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_2__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: i1 loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + %right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<4x2x2xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<4x2x2xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<4x2x2xi32>) -> 
tensor<4x2xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<4x2x2xi32> loc(#loc282) + %left_idx_16 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<4x2x2xi32> loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S4_2_2S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = 
arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_42 = arith.constant true loc(#loc298) + %cond_43 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_44 = arith.xori %left_isnan, %cond_43 : tensor<1x16xi1> loc(#loc298) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<1x16xi1> loc(#loc299) + %cond_46 = arith.ori %cond, %cond_45 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_46 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_43 = arith.ori %eq, %eq_42 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_43 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : 
tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = tt.splat %flip : i1 -> tensor<1x16xi1> loc(#loc311) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<1x16xi1> loc(#loc311) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_38 = arith.xori %x, %ret_37 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + %new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_38, %new_idxs_41 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"torch._inductor.runtime.triton_helpers._compare_and_swap_with_index__i32S1_16S_i32S1_16S_u1__(2,)cconstexpr_None__(4,)cconstexpr_3__(5,)cconstexpr_4__(6,)cconstexpr_True__(7,)cconstexpr_True_"(%x: tensor<1x16xi32> loc("x"(#loc102)), %idxs: tensor<1x16xi32> loc("idxs"(#loc102)), %flip: i1 loc("flip"(#loc102))) -> (tensor<1x16xi32>, tensor<1x16xi32>) attributes {noinline = false} { + %y = tt.reshape %x : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc266) + %right_mask = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc267) + %right_mask_0 = tt.expand_dims %right_mask {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc268) + 
%right_mask_1 = tt.expand_dims %right_mask_0 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc268) + %left_mask = arith.constant 1 : i32 loc(#loc269) + %left_mask_2 = arith.constant 1 : i32 loc(#loc269) + %left_mask_3 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc269) + %left_mask_4 = arith.subi %left_mask_3, %right_mask_1 : tensor<1x2x1xi32> loc(#loc269) + %ileft = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc270) + %ileft_5 = arith.muli %y, %ileft : tensor<8x2x1xi32> loc(#loc270) + %ileft_6 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%ileft_5) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc271) + %ileft_7 = tt.expand_dims %ileft_6 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc272) + %ileft_8 = tt.broadcast %ileft_7 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc273) + %iright = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc274) + %iright_9 = arith.muli %y, %iright : tensor<8x2x1xi32> loc(#loc274) + %iright_10 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%iright_9) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc275) + %iright_11 = tt.expand_dims %iright_10 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc276) + %iright_12 = tt.broadcast %iright_11 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc277) + %ileft_13 = tt.reshape %ileft_8 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc278) + %iright_14 = tt.reshape %iright_12 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc279) + %y_idx = tt.reshape %idxs : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc280) + %left_idx = tt.broadcast %left_mask_4 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc282) + %left_idx_15 = arith.muli %y_idx, %left_idx : tensor<8x2x1xi32> loc(#loc282) + %left_idx_16 = tt.call 
@"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%left_idx_15) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc283) + %left_idx_17 = tt.expand_dims %left_idx_16 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc284) + %left_idx_18 = tt.broadcast %left_idx_17 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc285) + %right_idx = tt.broadcast %right_mask_1 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc287) + %right_idx_19 = arith.muli %y_idx, %right_idx : tensor<8x2x1xi32> loc(#loc287) + %right_idx_20 = tt.call @"triton.language.standard.sum__i32S8_2_1S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%right_idx_19) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc288) + %right_idx_21 = tt.expand_dims %right_idx_20 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %right_idx_22 = tt.broadcast %right_idx_21 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %left_idx_23 = tt.reshape %left_idx_18 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc291) + %right_idx_24 = tt.reshape %right_idx_22 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc292) + %left_valid_mask = arith.constant true loc(#loc293) + %left_valid_mask_25 = arith.constant dense : tensor<1x16xi1> loc(#loc293) + %right_valid_mask = arith.constant true loc(#loc294) + %right_valid_mask_26 = arith.constant dense : tensor<1x16xi1> loc(#loc294) + %left_isnan = arith.cmpi ne, %ileft_13, %ileft_13 : tensor<1x16xi32> loc(#loc295) + %right_isnan = arith.cmpi ne, %iright_14, %iright_14 : tensor<1x16xi32> loc(#loc296) + %cond = arith.cmpi slt, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc329) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc135) + %1 = scf.if %0 -> (tensor<1x16xi1>) { + %cond_42 = arith.constant true loc(#loc298) + %cond_43 = arith.constant dense : tensor<1x16xi1> loc(#loc298) + %cond_44 = arith.xori %left_isnan, %cond_43 : 
tensor<1x16xi1> loc(#loc298) + %cond_45 = arith.andi %right_isnan, %cond_44 : tensor<1x16xi1> loc(#loc299) + %cond_46 = arith.ori %cond, %cond_45 : tensor<1x16xi1> loc(#loc330) + scf.yield %cond_46 : tensor<1x16xi1> loc(#loc330) + } else { + scf.yield %cond : tensor<1x16xi1> loc(#loc140) + } loc(#loc136) + %eq = arith.cmpi eq, %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc331) + %2 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__i32S1_16S__(%ileft_13) : (tensor<1x16xi32>) -> i1 loc(#loc142) + %3 = scf.if %2 -> (tensor<1x16xi1>) { + %eq_42 = arith.andi %left_isnan, %right_isnan : tensor<1x16xi1> loc(#loc302) + %eq_43 = arith.ori %eq, %eq_42 : tensor<1x16xi1> loc(#loc332) + scf.yield %eq_43 : tensor<1x16xi1> loc(#loc332) + } else { + scf.yield %eq : tensor<1x16xi1> loc(#loc140) + } loc(#loc143) + %cond_27 = arith.cmpi sgt, %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc304) + %cond_28 = arith.andi %3, %cond_27 : tensor<1x16xi1> loc(#loc305) + %cond_29 = arith.ori %1, %cond_28 : tensor<1x16xi1> loc(#loc306) + %cond_30 = arith.cmpi ugt, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc307) + %cond_31 = arith.cmpi eq, %right_valid_mask_26, %left_valid_mask_25 : tensor<1x16xi1> loc(#loc308) + %cond_32 = arith.andi %cond_31, %cond_29 : tensor<1x16xi1> loc(#loc309) + %cond_33 = arith.ori %cond_30, %cond_32 : tensor<1x16xi1> loc(#loc310) + %cond_34 = tt.splat %flip : i1 -> tensor<1x16xi1> loc(#loc311) + %cond_35 = arith.xori %cond_33, %cond_34 : tensor<1x16xi1> loc(#loc311) + %ret = arith.xori %ileft_13, %iright_14 : tensor<1x16xi32> loc(#loc313) + %ret_36 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%x) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc314) + %ret_37 = arith.select %cond_35, %ret, %ret_36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc315) + %ret_38 = arith.xori %x, %ret_37 : tensor<1x16xi32> loc(#loc316) + %new_idxs = arith.xori %left_idx_23, %right_idx_24 : tensor<1x16xi32> loc(#loc317) + 
%new_idxs_39 = tt.call @triton.language.standard.zeros_like__i32S1_16S__(%idxs) : (tensor<1x16xi32>) -> tensor<1x16xi32> loc(#loc318) + %new_idxs_40 = arith.select %cond_35, %new_idxs, %new_idxs_39 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc319) + %new_idxs_41 = arith.xori %idxs, %new_idxs_40 : tensor<1x16xi32> loc(#loc320) + tt.return %ret_38, %new_idxs_41 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc163) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x16xi32> loc(#loc164) + %5 = ub.poison : tensor<1x16xi32> loc(#loc164) + tt.return %4, %5 : tensor<1x16xi32>, tensor<1x16xi32> loc(#loc164) + } loc(#loc102) + tt.func private @"triton.language.standard.sum__i64S1_16S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<1x16xi64> loc("input"(#loc165))) -> tensor<1xi64> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: i64 loc(unknown), %arg2: i64 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__i64_i64__(%arg1, %arg2) : (i64, i64) -> i64 loc(#loc166) + tt.reduce.return %2 : i64 loc(#loc166) + }) : (tensor<1x16xi64>) -> tensor<1xi64> loc(#loc166) + tt.return %0 : tensor<1xi64> loc(#loc167) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1xi64> loc(#loc168) + tt.return %1 : tensor<1xi64> loc(#loc168) + } loc(#loc165) + tt.func private @triton.language.standard._sum_combine__i64_i64__(%a: i64 loc("a"(#loc169)), %b: i64 loc("b"(#loc169))) -> i64 attributes {noinline = false} { + %0 = arith.addi %a, %b : i64 loc(#loc170) + tt.return %0 : i64 loc(#loc171) + ^bb1: // no predecessors + %1 = ub.poison : i64 loc(#loc172) + tt.return %1 : i64 loc(#loc172) + } loc(#loc169) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":19:13) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":20:15) +#loc3 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:33) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":25:23) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:28) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":28:16) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":29:48) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc17 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":35:30) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":37:34) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc31 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":62:21) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":63:21) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":65:32) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc45 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":67:44) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":72:31) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":73:21) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":74:21) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc59 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:55) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc73 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:35) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:32) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:35) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:32) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:52) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc87 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc91 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc92 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:11) +#loc93 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":668:4) +#loc95 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc96 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc97 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc98 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc99 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc100 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:11) +#loc101 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":636:4) +#loc103 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc104 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:30) +#loc105 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":536:33) +#loc106 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc107 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc108 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc109 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc110 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc111 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc112 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc113 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc114 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc115 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc116 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc117 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc118 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc119 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc120 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc121 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc122 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc123 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc124 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc125 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc126 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc127 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc128 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc129 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc130 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":558:49) +#loc131 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":559:50) +#loc132 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":570:25) +#loc133 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":571:27) +#loc134 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc135 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:23) +#loc136 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":575:11) +#loc137 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:47) +#loc138 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:46) +#loc139 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":579:31) +#loc141 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc142 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:23) +#loc143 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":592:11) +#loc144 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:36) +#loc145 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":593:23) +#loc146 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc147 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc148 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc149 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":596:31) +#loc150 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:29) +#loc151 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:48) +#loc152 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":597:8) +#loc153 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc154 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc155 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc156 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:60) +#loc157 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc158 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc159 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc160 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:73) +#loc161 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc162 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc163 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:11) +#loc164 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":603:4) +#loc166 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc167 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc168 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc170 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc171 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc172 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc173 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc175 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:29) +#loc176 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:11) +#loc177 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:4) +#loc179 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:30) +#loc180 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:15) +#loc181 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:11) +#loc182 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:4) +#loc183 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc184 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc185 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc186 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc188 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:30) +#loc189 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:11) +#loc190 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":138:4) +#loc191 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":630:15) +#loc201 = loc("xnumel"(#loc1)) +#loc202 = loc("r0_numel"(#loc2)) +#loc203 = loc("xoffset"(#loc3)) +#loc204 = loc("xoffset"(#loc4)) +#loc205 = loc("xindex"(#loc5)) +#loc206 = loc("xindex"(#loc6)) +#loc207 = loc("xindex"(#loc7)) +#loc208 = loc("xmask"(#loc8)) +#loc209 = loc("r0_index"(#loc9)) +#loc210 = loc("r0_index"(#loc10)) +#loc211 = loc("r0_offset"(#loc11)) +#loc212 = loc("r0_mask"(#loc12)) +#loc213 = loc("tmp0"(#loc13)) +#loc214 = loc("tmp0"(#loc14)) +#loc215 = loc("tmp0"(#loc15)) +#loc216 = loc("tmp0"(#loc16)) +#loc217 = loc("tmp1"(#loc17)) +#loc218 = loc("tmp2"(#loc18)) +#loc219 = loc("tmp3"(#loc19)) +#loc220 = loc("tmp4"(#loc20)) +#loc221 = loc("tmp5"(#loc21)) +#loc222 = loc("tmp6"(#loc22)) +#loc223 = loc("tmp7"(#loc23)) +#loc224 = loc("tmp9"(#loc24)) +#loc225 = loc("tmp14"(#loc26)) +#loc226 = 
loc("tmp15"(#loc27)) +#loc227 = loc("tmp16"(#loc28)) +#loc228 = loc("tmp20"(#loc30)) +#loc229 = loc("tmp23"(#loc31)) +#loc230 = loc("tmp24"(#loc32)) +#loc231 = loc("tmp24"(#loc33)) +#loc232 = loc("tmp25"(#loc34)) +#loc233 = loc("tmp28"(#loc35)) +#loc234 = loc("tmp29"(#loc36)) +#loc235 = loc("tmp29"(#loc37)) +#loc236 = loc("tmp30"(#loc38)) +#loc237 = loc("tmp31"(#loc39)) +#loc238 = loc("tmp32"(#loc40)) +#loc239 = loc("tmp33"(#loc41)) +#loc240 = loc("tmp34"(#loc42)) +#loc241 = loc("tmp35"(#loc43)) +#loc242 = loc("tmp36"(#loc44)) +#loc243 = loc("tmp37"(#loc45)) +#loc244 = loc("tmp38"(#loc46)) +#loc245 = loc("tmp39"(#loc47)) +#loc246 = loc("tmp40"(#loc48)) +#loc247 = loc("tmp42"(#loc55)) +#loc248 = loc("tmp43"(#loc56)) +#loc249 = loc("tmp44"(#loc57)) +#loc250 = loc("tmp45"(#loc58)) +#loc251 = loc("tmp46"(#loc59)) +#loc252 = loc("tmp47"(#loc60)) +#loc253 = loc("tmp48"(#loc61)) +#loc254 = loc("tmp49"(#loc62)) +#loc259 = loc("flip"(#loc95)) +#loc260 = loc("flip"(#loc96)) +#loc261 = loc("flip"(#loc97)) +#loc262 = loc("flip"(#loc98)) +#loc266 = loc("y"(#loc103)) +#loc267 = loc("right_mask"(#loc104)) +#loc268 = loc("right_mask"(#loc105)) +#loc269 = loc("left_mask"(#loc106)) +#loc270 = loc("ileft"(#loc107)) +#loc271 = loc("ileft"(#loc108)) +#loc272 = loc("ileft"(#loc109)) +#loc273 = loc("ileft"(#loc110)) +#loc274 = loc("iright"(#loc111)) +#loc275 = loc("iright"(#loc112)) +#loc276 = loc("iright"(#loc113)) +#loc277 = loc("iright"(#loc114)) +#loc278 = loc("ileft"(#loc115)) +#loc279 = loc("iright"(#loc116)) +#loc280 = loc("y_idx"(#loc117)) +#loc281 = loc("left_idx"(#loc118)) +#loc282 = loc("left_idx"(#loc119)) +#loc283 = loc("left_idx"(#loc120)) +#loc284 = loc("left_idx"(#loc121)) +#loc285 = loc("left_idx"(#loc122)) +#loc286 = loc("right_idx"(#loc123)) +#loc287 = loc("right_idx"(#loc124)) +#loc288 = loc("right_idx"(#loc125)) +#loc289 = loc("right_idx"(#loc126)) +#loc290 = loc("right_idx"(#loc127)) +#loc291 = loc("left_idx"(#loc128)) +#loc292 = loc("right_idx"(#loc129)) +#loc293 = 
loc("left_valid_mask"(#loc130)) +#loc294 = loc("right_valid_mask"(#loc131)) +#loc295 = loc("left_isnan"(#loc132)) +#loc296 = loc("right_isnan"(#loc133)) +#loc297 = loc("cond"(#loc134)) +#loc298 = loc("cond"(#loc137)) +#loc299 = loc("cond"(#loc138)) +#loc300 = loc("cond"(#loc139)) +#loc301 = loc("eq"(#loc141)) +#loc302 = loc("eq"(#loc144)) +#loc303 = loc("eq"(#loc145)) +#loc304 = loc("cond"(#loc146)) +#loc305 = loc("cond"(#loc147)) +#loc306 = loc("cond"(#loc148)) +#loc307 = loc("cond"(#loc149)) +#loc308 = loc("cond"(#loc150)) +#loc309 = loc("cond"(#loc151)) +#loc310 = loc("cond"(#loc152)) +#loc311 = loc("cond"(#loc153)) +#loc312 = loc("cond"(#loc154)) +#loc313 = loc("ret"(#loc155)) +#loc314 = loc("ret"(#loc156)) +#loc315 = loc("ret"(#loc157)) +#loc316 = loc("ret"(#loc158)) +#loc317 = loc("new_idxs"(#loc159)) +#loc318 = loc("new_idxs"(#loc160)) +#loc319 = loc("new_idxs"(#loc161)) +#loc320 = loc("new_idxs"(#loc162)) +#loc324 = loc("input"(#loc173)) +#loc328 = loc("flip"(#loc191)) +#loc329 = loc("cond"(#loc297)) +#loc330 = loc("cond"(#loc300)) +#loc331 = loc("eq"(#loc301)) +#loc332 = loc("eq"(#loc303)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..8073133e74952df827e8826026f9a1da9da59b30 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttgir @@ -0,0 +1,1487 @@ +#blocked = #ttg.blocked<{sizePerThread = 
[1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 2, 8], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 2, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 2, 2], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 2, 1], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc1 = loc(unknown) +#loc16 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc44 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc72 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) +#loc113 = loc("in_ptr0"(#loc)) +#loc114 = loc("out_ptr4"(#loc)) +#loc115 = loc("out_ptr5"(#loc)) +#loc116 = loc("out_ptr6"(#loc)) +#loc117 = loc("out_ptr7"(#loc)) +#loc118 = loc("out_ptr8"(#loc)) +#loc119 = loc("out_ptr9"(#loc)) +#loc120 = loc("xnumel"(#loc)) +#loc121 = loc("r0_numel"(#loc)) +#loc136 = loc(callsite(#loc16 at #loc17)) +#loc142 = loc("ileft"(#loc25)) +#loc146 = loc("iright"(#loc30)) +#loc155 = loc("left_idx"(#loc39)) +#loc160 = loc("right_idx"(#loc44)) +#loc181 = loc(callsite(#loc16 at #loc65)) +#loc184 = loc("tmp24"(#loc68)) +#loc188 = loc("tmp29"(#loc72)) +#loc210 = loc(callsite(#loc21 at #loc136)) +#loc214 = loc(callsite(#loc21 at #loc181)) +#loc217 = loc(callsite(#loc1 at #loc184)) +#loc220 = loc(callsite(#loc1 at #loc188)) +#loc225 = loc(callsite(#loc142 at #loc210)) +#loc229 = loc(callsite(#loc146 at #loc210)) +#loc237 = loc(callsite(#loc155 at #loc210)) +#loc242 = loc(callsite(#loc160 at #loc210)) +#loc262 = loc(callsite(#loc142 at #loc214)) +#loc266 = loc(callsite(#loc146 at #loc214)) +#loc284 = loc(callsite(#loc155 at #loc214)) +#loc288 = loc(callsite(#loc160 at #loc214)) +#loc298 = loc(callsite(#loc1 at #loc225)) +#loc300 = loc(callsite(#loc1 at #loc229)) +#loc303 = loc(callsite(#loc1 at #loc237)) +#loc306 = loc(callsite(#loc1 at #loc242)) +#loc308 = loc(callsite(#loc1 at #loc262)) +#loc310 = loc(callsite(#loc1 at #loc266)) +#loc312 = loc(callsite(#loc1 at #loc284)) +#loc314 = loc(callsite(#loc1 at #loc288)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), 
%out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<1x16xi64, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<16384> : tensor<1x16xi64, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked2> loc(#loc1) + %cst_3 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked3> loc(#loc1) + %cst_4 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked4> loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c17_i32 = arith.constant 17 : i32 loc(#loc1) + %cst_5 = arith.constant dense<1> : tensor<1x16xi32, #blocked> loc(#loc1) + %cst_6 = arith.constant dense<0> : tensor<1x16xi32, #blocked5> loc(#loc1) + %cst_7 = arith.constant dense<17> : tensor<1x16xi32, #blocked5> loc(#loc1) + %cst_8 = arith.constant dense<16> : tensor<1x16xi32, #blocked5> loc(#loc1) + %cst_9 = arith.constant dense<16384> : tensor<1x16xi64, #blocked5> loc(#loc1) + %cst_10 = arith.constant dense<0> : tensor<1x16xi64, #blocked5> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc122) + %xmask = arith.cmpi slt, %xoffset, %c32_i32 : i32 loc(#loc123) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> loc(#loc124) + %r0_index_11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc124) + %r0_index_12 = 
tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x16xi32, #blocked5> loc(#loc124) + %r0_index_13 = tt.expand_dims %r0_index_11 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc124) + %tmp0 = arith.muli %xoffset, %c16_i32 : i32 loc(#loc125) + %tmp0_14 = tt.splat %tmp0 : i32 -> tensor<1x16xi32, #blocked5> loc(#loc204) + %tmp0_15 = tt.splat %tmp0 : i32 -> tensor<1x16xi32, #blocked> loc(#loc204) + %tmp0_16 = arith.addi %r0_index_12, %tmp0_14 : tensor<1x16xi32, #blocked5> loc(#loc126) + %tmp0_17 = arith.addi %r0_index_13, %tmp0_15 : tensor<1x16xi32, #blocked> loc(#loc126) + %tmp0_18 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked5> loc(#loc127) + %tmp0_19 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> loc(#loc127) + %tmp0_20 = tt.addptr %tmp0_18, %tmp0_16 : tensor<1x16x!tt.ptr, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc127) + %tmp0_21 = tt.addptr %tmp0_19, %tmp0_17 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> loc(#loc127) + %tmp0_22 = tt.splat %xmask : i1 -> tensor<1x16xi1, #blocked5> loc(#loc205) + %tmp0_23 = tt.splat %xmask : i1 -> tensor<1x16xi1, #blocked> loc(#loc205) + %tmp0_24 = tt.load %tmp0_20, %tmp0_22, %cst_10 : tensor<1x16x!tt.ptr, #blocked5> loc(#loc128) + %tmp0_25 = tt.load %tmp0_21, %tmp0_23, %cst : tensor<1x16x!tt.ptr, #blocked> loc(#loc128) + %tmp2 = arith.cmpi sgt, %tmp0_24, %cst_10 : tensor<1x16xi64, #blocked5> loc(#loc129) + %tmp2_26 = arith.cmpi sgt, %tmp0_25, %cst : tensor<1x16xi64, #blocked> loc(#loc129) + %tmp4 = arith.cmpi slt, %tmp0_24, %cst_9 : tensor<1x16xi64, #blocked5> loc(#loc130) + %tmp4_27 = arith.cmpi slt, %tmp0_25, %cst_0 : tensor<1x16xi64, #blocked> loc(#loc130) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<1x16xi1, #blocked5> loc(#loc131) + %tmp5_28 = arith.andi %tmp2_26, %tmp4_27 : tensor<1x16xi1, #blocked> loc(#loc131) + %tmp7 = arith.extui %tmp5 : 
tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc206) + %tmp9 = arith.trunci %r0_index_12 : tensor<1x16xi32, #blocked5> to tensor<1x16xi16, #blocked5> loc(#loc134) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> loc(#loc207) + %flip_29 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> loc(#loc207) + %flip_30 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> loc(#loc207) + %flip_31 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> loc(#loc207) + %flip_32 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> loc(#loc207) + %flip_33 = tt.expand_dims %flip_29 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked4}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> loc(#loc207) + %flip_34 = tt.expand_dims %flip_30 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> loc(#loc207) + %flip_35 = tt.expand_dims %flip_31 {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked1}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> loc(#loc207) + %flip_36 = tt.expand_dims %flip_32 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> -> tensor<1x2x1xi32, #blocked3> loc(#loc207) + %flip_37 = tt.expand_dims %flip_33 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked4}>> -> 
tensor<1x2x1xi32, #blocked4> loc(#loc207) + %flip_38 = tt.expand_dims %flip_34 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> loc(#loc207) + %flip_39 = tt.expand_dims %flip_35 {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #blocked1}>> -> tensor<1x2x1xi32, #blocked1> loc(#loc207) + %flip_40 = tt.broadcast %flip_36 : tensor<1x2x1xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc208) + %flip_41 = tt.reshape %flip_40 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc209) + %y = tt.reshape %tmp7 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc222) + %left_mask = arith.subi %cst_4, %flip_37 : tensor<1x2x1xi32, #blocked4> loc(#loc223) + %left_mask_42 = arith.subi %cst_3, %flip_36 : tensor<1x2x1xi32, #blocked3> loc(#loc223) + %left_mask_43 = arith.subi %cst_2, %flip_38 : tensor<1x2x1xi32, #blocked2> loc(#loc223) + %left_mask_44 = arith.subi %cst_1, %flip_39 : tensor<1x2x1xi32, #blocked1> loc(#loc223) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc224) + %ileft_45 = arith.muli %y, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc224) + %ileft_46 = "tt.reduce"(%ileft_45) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc297) + %ileft_47 = tt.expand_dims %ileft_46 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc226) + %ileft_48 = tt.broadcast %ileft_47 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc227) + %iright = tt.broadcast %flip_37 : tensor<1x2x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> 
loc(#loc228) + %iright_49 = arith.muli %y, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc228) + %iright_50 = "tt.reduce"(%iright_49) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc299) + %iright_51 = tt.expand_dims %iright_50 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc230) + %iright_52 = tt.broadcast %iright_51 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc231) + %ileft_53 = tt.reshape %ileft_48 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_54 = tt.reshape %iright_52 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx = tt.reshape %tmp9 : tensor<1x16xi16, #blocked5> -> tensor<8x2x1xi16, #blocked4> loc(#loc234) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc235) + %left_idx_55 = tt.broadcast %left_idx : tensor<1x2x1xi16, #blocked4> -> tensor<8x2x1xi16, #blocked4> loc(#loc236) + %left_idx_56 = arith.muli %y_idx, %left_idx_55 : tensor<8x2x1xi16, #blocked4> loc(#loc236) + %input = arith.extsi %left_idx_56 : tensor<8x2x1xi16, #blocked4> to tensor<8x2x1xi32, #blocked4> loc(#loc301) + %left_idx_57 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %left_idx_58 = tt.expand_dims %left_idx_57 
{axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc238) + %left_idx_59 = tt.broadcast %left_idx_58 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc239) + %right_idx = arith.trunci %flip_37 : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi16, #blocked4> loc(#loc240) + %right_idx_60 = tt.broadcast %right_idx : tensor<1x2x1xi16, #blocked4> -> tensor<8x2x1xi16, #blocked4> loc(#loc241) + %right_idx_61 = arith.muli %y_idx, %right_idx_60 : tensor<8x2x1xi16, #blocked4> loc(#loc241) + %input_62 = arith.extsi %right_idx_61 : tensor<8x2x1xi16, #blocked4> to tensor<8x2x1xi32, #blocked4> loc(#loc304) + %right_idx_63 = "tt.reduce"(%input_62) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %right_idx_64 = tt.expand_dims %right_idx_63 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc243) + %right_idx_65 = tt.broadcast %right_idx_64 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc244) + %left_idx_66 = tt.reshape %left_idx_59 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_67 = tt.reshape %right_idx_65 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond = arith.cmpi slt, %ileft_53, %iright_54 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq = arith.cmpi eq, %ileft_53, %iright_54 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_68 = arith.cmpi sgt, %left_idx_66, %right_idx_67 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_69 = arith.andi %eq, %cond_68 : tensor<1x16xi1, #blocked5> loc(#loc250) + 
%cond_70 = arith.ori %cond, %cond_69 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_71 = arith.extui %cond_70 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_72 = arith.xori %cond_71, %flip_41 : tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_73 = arith.cmpi ne, %cond_72, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret = arith.xori %ileft_53, %iright_54 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_74 = arith.select %cond_73, %ret, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_75 = arith.xori %tmp7, %ret_74 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs = arith.xori %left_idx_66, %right_idx_67 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_76 = arith.select %cond_73, %new_idxs, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_77 = arith.extsi %tmp9 : tensor<1x16xi16, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc259) + %new_idxs_78 = arith.xori %new_idxs_77, %new_idxs_76 : tensor<1x16xi32, #blocked5> loc(#loc259) + %flip_79 = tt.broadcast %flip_38 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc208) + %flip_80 = tt.reshape %flip_79 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc209) + %y_81 = tt.reshape %ret_75 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc222) + %ileft_82 = tt.broadcast %left_mask_42 : tensor<1x2x1xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc224) + %ileft_83 = arith.muli %y_81, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc224) + %ileft_84 = "tt.reduce"(%ileft_83) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc297) + 
%ileft_85 = tt.expand_dims %ileft_84 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc226) + %ileft_86 = tt.broadcast %ileft_85 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc227) + %iright_87 = arith.muli %y_81, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc228) + %iright_88 = "tt.reduce"(%iright_87) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc299) + %iright_89 = tt.expand_dims %iright_88 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc230) + %iright_90 = tt.broadcast %iright_89 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc231) + %ileft_91 = tt.reshape %ileft_86 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_92 = tt.reshape %iright_90 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_93 = tt.reshape %new_idxs_78 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc234) + %left_idx_94 = arith.muli %y_idx_93, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc236) + %left_idx_95 = "tt.reduce"(%left_idx_94) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %left_idx_96 = tt.expand_dims %left_idx_95 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = 
#blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc238) + %left_idx_97 = tt.broadcast %left_idx_96 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc239) + %right_idx_98 = arith.muli %y_idx_93, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc241) + %right_idx_99 = "tt.reduce"(%right_idx_98) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %right_idx_100 = tt.expand_dims %right_idx_99 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc243) + %right_idx_101 = tt.broadcast %right_idx_100 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc244) + %left_idx_102 = tt.reshape %left_idx_97 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_103 = tt.reshape %right_idx_101 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_104 = arith.cmpi slt, %ileft_91, %iright_92 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_105 = arith.cmpi eq, %ileft_91, %iright_92 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_106 = arith.cmpi sgt, %left_idx_102, %right_idx_103 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_107 = arith.andi %eq_105, %cond_106 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_108 = arith.ori %cond_104, %cond_107 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_109 = arith.extui %cond_108 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_110 = arith.xori %cond_109, %flip_80 : tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_111 = arith.cmpi ne, %cond_110, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret_112 = 
arith.xori %ileft_91, %iright_92 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_113 = arith.select %cond_111, %ret_112, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_114 = arith.xori %ret_75, %ret_113 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_115 = arith.xori %left_idx_102, %right_idx_103 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_116 = arith.select %cond_111, %new_idxs_115, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_117 = arith.xori %new_idxs_78, %new_idxs_116 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_118 = tt.reshape %ret_114 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc222) + %ileft_119 = arith.muli %y_118, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc224) + %ileft_120 = "tt.reduce"(%ileft_119) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc297) + %ileft_121 = tt.expand_dims %ileft_120 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc226) + %ileft_122 = tt.broadcast %ileft_121 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc227) + %iright_123 = arith.muli %y_118, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc228) + %iright_124 = "tt.reduce"(%iright_123) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc299) + %iright_125 = 
tt.expand_dims %iright_124 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc230) + %iright_126 = tt.broadcast %iright_125 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc231) + %ileft_127 = tt.reshape %ileft_122 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_128 = tt.reshape %iright_126 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_129 = tt.reshape %new_idxs_117 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc234) + %left_idx_130 = arith.muli %y_idx_129, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc236) + %left_idx_131 = "tt.reduce"(%left_idx_130) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %left_idx_132 = tt.expand_dims %left_idx_131 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc238) + %left_idx_133 = tt.broadcast %left_idx_132 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc239) + %right_idx_134 = arith.muli %y_idx_129, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc241) + %right_idx_135 = "tt.reduce"(%right_idx_134) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %right_idx_136 = tt.expand_dims %right_idx_135 {axis = 1 : 
i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc243) + %right_idx_137 = tt.broadcast %right_idx_136 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc244) + %left_idx_138 = tt.reshape %left_idx_133 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_139 = tt.reshape %right_idx_137 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_140 = arith.cmpi slt, %ileft_127, %iright_128 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_141 = arith.cmpi eq, %ileft_127, %iright_128 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_142 = arith.cmpi sgt, %left_idx_138, %right_idx_139 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_143 = arith.andi %eq_141, %cond_142 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_144 = arith.ori %cond_140, %cond_143 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_145 = arith.extui %cond_144 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_146 = arith.xori %cond_145, %flip_80 : tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_147 = arith.cmpi ne, %cond_146, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret_148 = arith.xori %ileft_127, %iright_128 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_149 = arith.select %cond_147, %ret_148, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_150 = arith.xori %ret_114, %ret_149 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_151 = arith.xori %left_idx_138, %right_idx_139 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_152 = arith.select %cond_147, %new_idxs_151, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_153 = arith.xori %new_idxs_117, %new_idxs_152 : tensor<1x16xi32, #blocked5> loc(#loc259) + %flip_154 = tt.broadcast %flip_39 : tensor<1x2x1xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc208) + %flip_155 = 
tt.reshape %flip_154 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc209) + %y_156 = tt.reshape %ret_150 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc222) + %ileft_157 = tt.broadcast %left_mask_43 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc224) + %ileft_158 = arith.muli %y_156, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc224) + %ileft_159 = "tt.reduce"(%ileft_158) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc297) + %ileft_160 = tt.expand_dims %ileft_159 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc226) + %ileft_161 = tt.broadcast %ileft_160 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc227) + %iright_162 = arith.muli %y_156, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc228) + %iright_163 = "tt.reduce"(%iright_162) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc299) + %iright_164 = tt.expand_dims %iright_163 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc230) + %iright_165 = tt.broadcast %iright_164 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc231) + %ileft_166 = tt.reshape %ileft_161 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_167 = tt.reshape 
%iright_165 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_168 = tt.reshape %new_idxs_153 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc234) + %left_idx_169 = arith.muli %y_idx_168, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc236) + %left_idx_170 = "tt.reduce"(%left_idx_169) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %left_idx_171 = tt.expand_dims %left_idx_170 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc238) + %left_idx_172 = tt.broadcast %left_idx_171 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc239) + %right_idx_173 = arith.muli %y_idx_168, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc241) + %right_idx_174 = "tt.reduce"(%right_idx_173) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %right_idx_175 = tt.expand_dims %right_idx_174 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc243) + %right_idx_176 = tt.broadcast %right_idx_175 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc244) + %left_idx_177 = tt.reshape %left_idx_172 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_178 = tt.reshape %right_idx_176 : 
tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_179 = arith.cmpi slt, %ileft_166, %iright_167 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_180 = arith.cmpi eq, %ileft_166, %iright_167 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_181 = arith.cmpi sgt, %left_idx_177, %right_idx_178 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_182 = arith.andi %eq_180, %cond_181 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_183 = arith.ori %cond_179, %cond_182 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_184 = arith.extui %cond_183 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_185 = arith.xori %cond_184, %flip_155 : tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_186 = arith.cmpi ne, %cond_185, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret_187 = arith.xori %ileft_166, %iright_167 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_188 = arith.select %cond_186, %ret_187, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_189 = arith.xori %ret_150, %ret_188 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_190 = arith.xori %left_idx_177, %right_idx_178 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_191 = arith.select %cond_186, %new_idxs_190, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_192 = arith.xori %new_idxs_153, %new_idxs_191 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_193 = tt.reshape %ret_189 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc222) + %ileft_194 = arith.muli %y_193, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc224) + %ileft_195 = "tt.reduce"(%ileft_194) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<4x2x2xi32, #blocked3>) -> 
tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc297) + %ileft_196 = tt.expand_dims %ileft_195 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc226) + %ileft_197 = tt.broadcast %ileft_196 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc227) + %iright_198 = arith.muli %y_193, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc228) + %iright_199 = "tt.reduce"(%iright_198) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc299) + %iright_200 = tt.expand_dims %iright_199 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc230) + %iright_201 = tt.broadcast %iright_200 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc231) + %ileft_202 = tt.reshape %ileft_197 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_203 = tt.reshape %iright_201 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_204 = tt.reshape %new_idxs_192 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc234) + %left_idx_205 = arith.muli %y_idx_204, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc236) + %left_idx_206 = "tt.reduce"(%left_idx_205) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %left_idx_207 
= tt.expand_dims %left_idx_206 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc238) + %left_idx_208 = tt.broadcast %left_idx_207 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc239) + %right_idx_209 = arith.muli %y_idx_204, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc241) + %right_idx_210 = "tt.reduce"(%right_idx_209) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %right_idx_211 = tt.expand_dims %right_idx_210 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc243) + %right_idx_212 = tt.broadcast %right_idx_211 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc244) + %left_idx_213 = tt.reshape %left_idx_208 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_214 = tt.reshape %right_idx_212 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_215 = arith.cmpi slt, %ileft_202, %iright_203 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_216 = arith.cmpi eq, %ileft_202, %iright_203 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_217 = arith.cmpi sgt, %left_idx_213, %right_idx_214 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_218 = arith.andi %eq_216, %cond_217 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_219 = arith.ori %cond_215, %cond_218 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_220 = arith.extui %cond_219 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_221 = arith.xori %cond_220, %flip_155 : tensor<1x16xi32, #blocked5> loc(#loc252) + 
%cond_222 = arith.cmpi ne, %cond_221, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret_223 = arith.xori %ileft_202, %iright_203 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_224 = arith.select %cond_222, %ret_223, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_225 = arith.xori %ret_189, %ret_224 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_226 = arith.xori %left_idx_213, %right_idx_214 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_227 = arith.select %cond_222, %new_idxs_226, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_228 = arith.xori %new_idxs_192, %new_idxs_227 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_229 = tt.reshape %ret_225 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc222) + %ileft_230 = arith.muli %y_229, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc224) + %ileft_231 = "tt.reduce"(%ileft_230) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc297) + %ileft_232 = tt.expand_dims %ileft_231 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc226) + %ileft_233 = tt.broadcast %ileft_232 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc227) + %iright_234 = arith.muli %y_229, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc228) + %iright_235 = "tt.reduce"(%iright_234) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<8x2x1xi32, 
#blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc299) + %iright_236 = tt.expand_dims %iright_235 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc230) + %iright_237 = tt.broadcast %iright_236 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc231) + %ileft_238 = tt.reshape %ileft_233 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_239 = tt.reshape %iright_237 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_240 = tt.reshape %new_idxs_228 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc234) + %left_idx_241 = arith.muli %y_idx_240, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc236) + %left_idx_242 = "tt.reduce"(%left_idx_241) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %left_idx_243 = tt.expand_dims %left_idx_242 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc238) + %left_idx_244 = tt.broadcast %left_idx_243 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc239) + %right_idx_245 = arith.muli %y_idx_240, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc241) + %right_idx_246 = "tt.reduce"(%right_idx_245) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 
1, parent = #blocked4}>> loc(#loc305) + %right_idx_247 = tt.expand_dims %right_idx_246 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc243) + %right_idx_248 = tt.broadcast %right_idx_247 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc244) + %left_idx_249 = tt.reshape %left_idx_244 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_250 = tt.reshape %right_idx_248 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_251 = arith.cmpi slt, %ileft_238, %iright_239 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_252 = arith.cmpi eq, %ileft_238, %iright_239 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_253 = arith.cmpi sgt, %left_idx_249, %right_idx_250 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_254 = arith.andi %eq_252, %cond_253 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_255 = arith.ori %cond_251, %cond_254 : tensor<1x16xi1, #blocked5> loc(#loc251) + %cond_256 = arith.extui %cond_255 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_257 = arith.xori %cond_256, %flip_155 : tensor<1x16xi32, #blocked5> loc(#loc252) + %cond_258 = arith.cmpi ne, %cond_257, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc253) + %ret_259 = arith.xori %ileft_238, %iright_239 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_260 = arith.select %cond_258, %ret_259, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_261 = arith.xori %ret_225, %ret_260 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_262 = arith.xori %left_idx_249, %right_idx_250 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_263 = arith.select %cond_258, %new_idxs_262, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_264 = arith.xori %new_idxs_228, %new_idxs_263 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_265 = tt.reshape %ret_261 : 
tensor<1x16xi32, #blocked5> -> tensor<1x2x8xi32, #blocked1> loc(#loc222) + %ileft_266 = tt.broadcast %left_mask_44 : tensor<1x2x1xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc224) + %ileft_267 = arith.muli %y_265, %ileft_266 : tensor<1x2x8xi32, #blocked1> loc(#loc224) + %ileft_268 = "tt.reduce"(%ileft_267) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc297) + %ileft_269 = tt.expand_dims %ileft_268 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc226) + %ileft_270 = tt.broadcast %ileft_269 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc227) + %iright_271 = arith.muli %y_265, %flip_154 : tensor<1x2x8xi32, #blocked1> loc(#loc228) + %iright_272 = "tt.reduce"(%iright_271) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc299) + %iright_273 = tt.expand_dims %iright_272 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc230) + %iright_274 = tt.broadcast %iright_273 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc231) + %ileft_275 = tt.reshape %ileft_270 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_276 = tt.reshape %iright_274 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_277 = tt.reshape %new_idxs_264 
: tensor<1x16xi32, #blocked5> -> tensor<1x2x8xi32, #blocked1> loc(#loc234) + %left_idx_278 = arith.muli %y_idx_277, %ileft_266 : tensor<1x2x8xi32, #blocked1> loc(#loc236) + %left_idx_279 = "tt.reduce"(%left_idx_278) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc302) + %left_idx_280 = tt.expand_dims %left_idx_279 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc238) + %left_idx_281 = tt.broadcast %left_idx_280 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc239) + %right_idx_282 = arith.muli %y_idx_277, %flip_154 : tensor<1x2x8xi32, #blocked1> loc(#loc241) + %right_idx_283 = "tt.reduce"(%right_idx_282) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc305) + %right_idx_284 = tt.expand_dims %right_idx_283 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc243) + %right_idx_285 = tt.broadcast %right_idx_284 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc244) + %left_idx_286 = tt.reshape %left_idx_281 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_287 = tt.reshape %right_idx_285 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_288 = arith.cmpi slt, %ileft_275, %iright_276 : 
tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_289 = arith.cmpi eq, %ileft_275, %iright_276 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_290 = arith.cmpi sgt, %left_idx_286, %right_idx_287 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_291 = arith.andi %eq_289, %cond_290 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_292 = arith.ori %cond_288, %cond_291 : tensor<1x16xi1, #blocked5> loc(#loc251) + %ret_293 = arith.xori %ileft_275, %iright_276 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_294 = arith.select %cond_292, %ret_293, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_295 = arith.xori %ret_261, %ret_294 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_296 = arith.xori %left_idx_286, %right_idx_287 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_297 = arith.select %cond_292, %new_idxs_296, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_298 = arith.xori %new_idxs_264, %new_idxs_297 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_299 = tt.reshape %ret_295 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc222) + %ileft_300 = arith.muli %y_299, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc224) + %ileft_301 = "tt.reduce"(%ileft_300) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc297) + %ileft_302 = tt.expand_dims %ileft_301 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc226) + %ileft_303 = tt.broadcast %ileft_302 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc227) + %iright_304 = arith.muli %y_299, %flip_79 : tensor<2x2x4xi32, #blocked2> 
loc(#loc228) + %iright_305 = "tt.reduce"(%iright_304) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc299) + %iright_306 = tt.expand_dims %iright_305 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc230) + %iright_307 = tt.broadcast %iright_306 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc231) + %ileft_308 = tt.reshape %ileft_303 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_309 = tt.reshape %iright_307 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_310 = tt.reshape %new_idxs_298 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc234) + %left_idx_311 = arith.muli %y_idx_310, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc236) + %left_idx_312 = "tt.reduce"(%left_idx_311) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc302) + %left_idx_313 = tt.expand_dims %left_idx_312 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc238) + %left_idx_314 = tt.broadcast %left_idx_313 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc239) + %right_idx_315 = arith.muli %y_idx_310, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc241) + %right_idx_316 = "tt.reduce"(%right_idx_315) <{axis 
= 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc305) + %right_idx_317 = tt.expand_dims %right_idx_316 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc243) + %right_idx_318 = tt.broadcast %right_idx_317 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc244) + %left_idx_319 = tt.reshape %left_idx_314 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_320 = tt.reshape %right_idx_318 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_321 = arith.cmpi slt, %ileft_308, %iright_309 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_322 = arith.cmpi eq, %ileft_308, %iright_309 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_323 = arith.cmpi sgt, %left_idx_319, %right_idx_320 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_324 = arith.andi %eq_322, %cond_323 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_325 = arith.ori %cond_321, %cond_324 : tensor<1x16xi1, #blocked5> loc(#loc251) + %ret_326 = arith.xori %ileft_308, %iright_309 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_327 = arith.select %cond_325, %ret_326, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_328 = arith.xori %ret_295, %ret_327 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_329 = arith.xori %left_idx_319, %right_idx_320 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_330 = arith.select %cond_325, %new_idxs_329, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_331 = arith.xori %new_idxs_298, %new_idxs_330 : tensor<1x16xi32, 
#blocked5> loc(#loc259) + %y_332 = tt.reshape %ret_328 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc222) + %ileft_333 = arith.muli %y_332, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc224) + %ileft_334 = "tt.reduce"(%ileft_333) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc297) + %ileft_335 = tt.expand_dims %ileft_334 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc226) + %ileft_336 = tt.broadcast %ileft_335 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc227) + %iright_337 = arith.muli %y_332, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc228) + %iright_338 = "tt.reduce"(%iright_337) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc299) + %iright_339 = tt.expand_dims %iright_338 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc230) + %iright_340 = tt.broadcast %iright_339 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc231) + %ileft_341 = tt.reshape %ileft_336 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_342 = tt.reshape %iright_340 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_343 = tt.reshape %new_idxs_331 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> 
loc(#loc234) + %left_idx_344 = arith.muli %y_idx_343, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc236) + %left_idx_345 = "tt.reduce"(%left_idx_344) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc302) + %left_idx_346 = tt.expand_dims %left_idx_345 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc238) + %left_idx_347 = tt.broadcast %left_idx_346 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc239) + %right_idx_348 = arith.muli %y_idx_343, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc241) + %right_idx_349 = "tt.reduce"(%right_idx_348) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc305) + %right_idx_350 = tt.expand_dims %right_idx_349 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc243) + %right_idx_351 = tt.broadcast %right_idx_350 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc244) + %left_idx_352 = tt.reshape %left_idx_347 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_353 = tt.reshape %right_idx_351 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_354 = arith.cmpi slt, %ileft_341, %iright_342 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_355 = arith.cmpi 
eq, %ileft_341, %iright_342 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_356 = arith.cmpi sgt, %left_idx_352, %right_idx_353 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_357 = arith.andi %eq_355, %cond_356 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_358 = arith.ori %cond_354, %cond_357 : tensor<1x16xi1, #blocked5> loc(#loc251) + %ret_359 = arith.xori %ileft_341, %iright_342 : tensor<1x16xi32, #blocked5> loc(#loc254) + %ret_360 = arith.select %cond_358, %ret_359, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc255) + %ret_361 = arith.xori %ret_328, %ret_360 : tensor<1x16xi32, #blocked5> loc(#loc256) + %new_idxs_362 = arith.xori %left_idx_352, %right_idx_353 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_363 = arith.select %cond_358, %new_idxs_362, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_364 = arith.xori %new_idxs_331, %new_idxs_363 : tensor<1x16xi32, #blocked5> loc(#loc259) + %y_365 = tt.reshape %ret_361 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc222) + %ileft_366 = arith.muli %y_365, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc224) + %ileft_367 = "tt.reduce"(%ileft_366) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc225)), %ileft_744: i32 loc(callsite(#loc1 at #loc225))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc315) + tt.reduce.return %ileft_745 : i32 loc(#loc297) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc297) + %ileft_368 = tt.expand_dims %ileft_367 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc226) + %ileft_369 = tt.broadcast %ileft_368 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc227) + %iright_370 = arith.muli %y_365, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc228) + %iright_371 = "tt.reduce"(%iright_370) <{axis = 1 : 
i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc229)), %iright_744: i32 loc(callsite(#loc1 at #loc229))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc316) + tt.reduce.return %iright_745 : i32 loc(#loc299) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc299) + %iright_372 = tt.expand_dims %iright_371 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc230) + %iright_373 = tt.broadcast %iright_372 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc231) + %ileft_374 = tt.reshape %ileft_369 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc232) + %iright_375 = tt.reshape %iright_373 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc233) + %y_idx_376 = tt.reshape %new_idxs_364 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc234) + %left_idx_377 = arith.muli %y_idx_376, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc236) + %left_idx_378 = "tt.reduce"(%left_idx_377) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc237)), %left_idx_744: i32 loc(callsite(#loc1 at #loc237))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc317) + tt.reduce.return %left_idx_745 : i32 loc(#loc302) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc302) + %left_idx_379 = tt.expand_dims %left_idx_378 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc238) + %left_idx_380 = tt.broadcast %left_idx_379 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc239) + %right_idx_381 = arith.muli %y_idx_376, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc241) + %right_idx_382 = "tt.reduce"(%right_idx_381) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at 
#loc242)), %right_idx_744: i32 loc(callsite(#loc1 at #loc242))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc318) + tt.reduce.return %right_idx_745 : i32 loc(#loc305) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc305) + %right_idx_383 = tt.expand_dims %right_idx_382 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc243) + %right_idx_384 = tt.broadcast %right_idx_383 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc244) + %left_idx_385 = tt.reshape %left_idx_380 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc245) + %right_idx_386 = tt.reshape %right_idx_384 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc246) + %cond_387 = arith.cmpi slt, %ileft_374, %iright_375 : tensor<1x16xi32, #blocked5> loc(#loc247) + %eq_388 = arith.cmpi eq, %ileft_374, %iright_375 : tensor<1x16xi32, #blocked5> loc(#loc248) + %cond_389 = arith.cmpi sgt, %left_idx_385, %right_idx_386 : tensor<1x16xi32, #blocked5> loc(#loc249) + %cond_390 = arith.andi %eq_388, %cond_389 : tensor<1x16xi1, #blocked5> loc(#loc250) + %cond_391 = arith.ori %cond_387, %cond_390 : tensor<1x16xi1, #blocked5> loc(#loc251) + %new_idxs_392 = arith.xori %left_idx_385, %right_idx_386 : tensor<1x16xi32, #blocked5> loc(#loc257) + %new_idxs_393 = arith.select %cond_391, %new_idxs_392, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc258) + %new_idxs_394 = arith.xori %new_idxs_364, %new_idxs_393 : tensor<1x16xi32, #blocked5> loc(#loc259) + %tmp14 = arith.cmpi eq, %tmp0_24, %cst_9 : tensor<1x16xi64, #blocked5> loc(#loc178) + %tmp14_395 = arith.cmpi eq, %tmp0_25, %cst_0 : tensor<1x16xi64, #blocked> loc(#loc178) + %tmp16 = arith.extui %tmp14 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc213) + %y_396 = tt.reshape %tmp16 : tensor<1x16xi32, #blocked5> -> 
tensor<8x2x1xi32, #blocked4> loc(#loc260) + %ileft_397 = arith.muli %y_396, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc261) + %ileft_398 = "tt.reduce"(%ileft_397) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc307) + %ileft_399 = tt.expand_dims %ileft_398 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc263) + %ileft_400 = tt.broadcast %ileft_399 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc264) + %iright_401 = arith.muli %y_396, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc265) + %iright_402 = "tt.reduce"(%iright_401) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc309) + %iright_403 = tt.expand_dims %iright_402 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc267) + %iright_404 = tt.broadcast %iright_403 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc268) + %ileft_405 = tt.reshape %ileft_400 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_406 = tt.reshape %iright_404 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %cond_407 = arith.cmpi slt, %ileft_405, %iright_406 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_408 = arith.cmpi eq, %ileft_405, %iright_406 : tensor<1x16xi32, #blocked5> loc(#loc272) + 
%cond_409 = arith.andi %eq_408, %cond_68 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_410 = arith.ori %cond_407, %cond_409 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_411 = arith.extui %cond_410 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_412 = arith.xori %cond_411, %flip_41 : tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_413 = arith.cmpi ne, %cond_412, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_414 = arith.xori %ileft_405, %iright_406 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_415 = arith.select %cond_413, %ret_414, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_416 = arith.xori %tmp16, %ret_415 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_417 = arith.select %cond_413, %new_idxs, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_418 = arith.xori %new_idxs_77, %new_idxs_417 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_419 = tt.reshape %ret_416 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc260) + %ileft_420 = arith.muli %y_419, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc261) + %ileft_421 = "tt.reduce"(%ileft_420) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc307) + %ileft_422 = tt.expand_dims %ileft_421 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc263) + %ileft_423 = tt.broadcast %ileft_422 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc264) + %iright_424 = arith.muli %y_419, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc265) + %iright_425 = "tt.reduce"(%iright_424) <{axis = 1 : 
i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc309) + %iright_426 = tt.expand_dims %iright_425 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc267) + %iright_427 = tt.broadcast %iright_426 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc268) + %ileft_428 = tt.reshape %ileft_423 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_429 = tt.reshape %iright_427 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_430 = tt.reshape %new_idxs_418 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc282) + %left_idx_431 = arith.muli %y_idx_430, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc283) + %left_idx_432 = "tt.reduce"(%left_idx_431) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc311) + %left_idx_433 = tt.expand_dims %left_idx_432 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc285) + %left_idx_434 = tt.broadcast %left_idx_433 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc286) + %right_idx_435 = arith.muli %y_idx_430, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc287) + %right_idx_436 = "tt.reduce"(%right_idx_435) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at 
#loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc313) + %right_idx_437 = tt.expand_dims %right_idx_436 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc289) + %right_idx_438 = tt.broadcast %right_idx_437 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc290) + %left_idx_439 = tt.reshape %left_idx_434 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_440 = tt.reshape %right_idx_438 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_441 = arith.cmpi slt, %ileft_428, %iright_429 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_442 = arith.cmpi eq, %ileft_428, %iright_429 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_443 = arith.cmpi sgt, %left_idx_439, %right_idx_440 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_444 = arith.andi %eq_442, %cond_443 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_445 = arith.ori %cond_441, %cond_444 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_446 = arith.extui %cond_445 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_447 = arith.xori %cond_446, %flip_80 : tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_448 = arith.cmpi ne, %cond_447, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_449 = arith.xori %ileft_428, %iright_429 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_450 = arith.select %cond_448, %ret_449, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_451 = arith.xori %ret_416, %ret_450 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_452 = arith.xori %left_idx_439, %right_idx_440 : tensor<1x16xi32, #blocked5> 
loc(#loc294) + %new_idxs_453 = arith.select %cond_448, %new_idxs_452, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_454 = arith.xori %new_idxs_418, %new_idxs_453 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_455 = tt.reshape %ret_451 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc260) + %ileft_456 = arith.muli %y_455, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc261) + %ileft_457 = "tt.reduce"(%ileft_456) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc307) + %ileft_458 = tt.expand_dims %ileft_457 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc263) + %ileft_459 = tt.broadcast %ileft_458 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc264) + %iright_460 = arith.muli %y_455, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc265) + %iright_461 = "tt.reduce"(%iright_460) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc309) + %iright_462 = tt.expand_dims %iright_461 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc267) + %iright_463 = tt.broadcast %iright_462 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc268) + %ileft_464 = tt.reshape %ileft_459 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> 
loc(#loc269) + %iright_465 = tt.reshape %iright_463 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_466 = tt.reshape %new_idxs_454 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc282) + %left_idx_467 = arith.muli %y_idx_466, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc283) + %left_idx_468 = "tt.reduce"(%left_idx_467) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc311) + %left_idx_469 = tt.expand_dims %left_idx_468 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc285) + %left_idx_470 = tt.broadcast %left_idx_469 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc286) + %right_idx_471 = arith.muli %y_idx_466, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc287) + %right_idx_472 = "tt.reduce"(%right_idx_471) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc313) + %right_idx_473 = tt.expand_dims %right_idx_472 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc289) + %right_idx_474 = tt.broadcast %right_idx_473 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc290) + %left_idx_475 = tt.reshape %left_idx_470 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_476 
= tt.reshape %right_idx_474 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_477 = arith.cmpi slt, %ileft_464, %iright_465 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_478 = arith.cmpi eq, %ileft_464, %iright_465 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_479 = arith.cmpi sgt, %left_idx_475, %right_idx_476 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_480 = arith.andi %eq_478, %cond_479 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_481 = arith.ori %cond_477, %cond_480 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_482 = arith.extui %cond_481 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_483 = arith.xori %cond_482, %flip_80 : tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_484 = arith.cmpi ne, %cond_483, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_485 = arith.xori %ileft_464, %iright_465 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_486 = arith.select %cond_484, %ret_485, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_487 = arith.xori %ret_451, %ret_486 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_488 = arith.xori %left_idx_475, %right_idx_476 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_489 = arith.select %cond_484, %new_idxs_488, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_490 = arith.xori %new_idxs_454, %new_idxs_489 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_491 = tt.reshape %ret_487 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc260) + %ileft_492 = arith.muli %y_491, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc261) + %ileft_493 = "tt.reduce"(%ileft_492) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : 
(tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc307) + %ileft_494 = tt.expand_dims %ileft_493 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc263) + %ileft_495 = tt.broadcast %ileft_494 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc264) + %iright_496 = arith.muli %y_491, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc265) + %iright_497 = "tt.reduce"(%iright_496) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc309) + %iright_498 = tt.expand_dims %iright_497 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc267) + %iright_499 = tt.broadcast %iright_498 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc268) + %ileft_500 = tt.reshape %ileft_495 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_501 = tt.reshape %iright_499 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_502 = tt.reshape %new_idxs_490 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc282) + %left_idx_503 = arith.muli %y_idx_502, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc283) + %left_idx_504 = "tt.reduce"(%left_idx_503) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = 
#blocked2}>> loc(#loc311) + %left_idx_505 = tt.expand_dims %left_idx_504 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc285) + %left_idx_506 = tt.broadcast %left_idx_505 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc286) + %right_idx_507 = arith.muli %y_idx_502, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc287) + %right_idx_508 = "tt.reduce"(%right_idx_507) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc313) + %right_idx_509 = tt.expand_dims %right_idx_508 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc289) + %right_idx_510 = tt.broadcast %right_idx_509 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc290) + %left_idx_511 = tt.reshape %left_idx_506 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_512 = tt.reshape %right_idx_510 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_513 = arith.cmpi slt, %ileft_500, %iright_501 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_514 = arith.cmpi eq, %ileft_500, %iright_501 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_515 = arith.cmpi sgt, %left_idx_511, %right_idx_512 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_516 = arith.andi %eq_514, %cond_515 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_517 = arith.ori %cond_513, %cond_516 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_518 = arith.extui %cond_517 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_519 = arith.xori %cond_518, %flip_155 : 
tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_520 = arith.cmpi ne, %cond_519, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_521 = arith.xori %ileft_500, %iright_501 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_522 = arith.select %cond_520, %ret_521, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_523 = arith.xori %ret_487, %ret_522 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_524 = arith.xori %left_idx_511, %right_idx_512 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_525 = arith.select %cond_520, %new_idxs_524, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_526 = arith.xori %new_idxs_490, %new_idxs_525 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_527 = tt.reshape %ret_523 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc260) + %ileft_528 = arith.muli %y_527, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc261) + %ileft_529 = "tt.reduce"(%ileft_528) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc307) + %ileft_530 = tt.expand_dims %ileft_529 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc263) + %ileft_531 = tt.broadcast %ileft_530 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc264) + %iright_532 = arith.muli %y_527, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc265) + %iright_533 = "tt.reduce"(%iright_532) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return 
%iright_745 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc309) + %iright_534 = tt.expand_dims %iright_533 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc267) + %iright_535 = tt.broadcast %iright_534 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc268) + %ileft_536 = tt.reshape %ileft_531 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_537 = tt.reshape %iright_535 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_538 = tt.reshape %new_idxs_526 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc282) + %left_idx_539 = arith.muli %y_idx_538, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc283) + %left_idx_540 = "tt.reduce"(%left_idx_539) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc311) + %left_idx_541 = tt.expand_dims %left_idx_540 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc285) + %left_idx_542 = tt.broadcast %left_idx_541 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc286) + %right_idx_543 = arith.muli %y_idx_538, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc287) + %right_idx_544 = "tt.reduce"(%right_idx_543) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : 
(tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc313) + %right_idx_545 = tt.expand_dims %right_idx_544 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc289) + %right_idx_546 = tt.broadcast %right_idx_545 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc290) + %left_idx_547 = tt.reshape %left_idx_542 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_548 = tt.reshape %right_idx_546 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_549 = arith.cmpi slt, %ileft_536, %iright_537 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_550 = arith.cmpi eq, %ileft_536, %iright_537 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_551 = arith.cmpi sgt, %left_idx_547, %right_idx_548 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_552 = arith.andi %eq_550, %cond_551 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_553 = arith.ori %cond_549, %cond_552 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_554 = arith.extui %cond_553 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_555 = arith.xori %cond_554, %flip_155 : tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_556 = arith.cmpi ne, %cond_555, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_557 = arith.xori %ileft_536, %iright_537 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_558 = arith.select %cond_556, %ret_557, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_559 = arith.xori %ret_523, %ret_558 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_560 = arith.xori %left_idx_547, %right_idx_548 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_561 = arith.select %cond_556, %new_idxs_560, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_562 = arith.xori %new_idxs_526, %new_idxs_561 : 
tensor<1x16xi32, #blocked5> loc(#loc281) + %y_563 = tt.reshape %ret_559 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc260) + %ileft_564 = arith.muli %y_563, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc261) + %ileft_565 = "tt.reduce"(%ileft_564) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc307) + %ileft_566 = tt.expand_dims %ileft_565 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc263) + %ileft_567 = tt.broadcast %ileft_566 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc264) + %iright_568 = arith.muli %y_563, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc265) + %iright_569 = "tt.reduce"(%iright_568) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc309) + %iright_570 = tt.expand_dims %iright_569 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc267) + %iright_571 = tt.broadcast %iright_570 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc268) + %ileft_572 = tt.reshape %ileft_567 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_573 = tt.reshape %iright_571 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_574 = tt.reshape %new_idxs_562 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, 
#blocked4> loc(#loc282) + %left_idx_575 = arith.muli %y_idx_574, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc283) + %left_idx_576 = "tt.reduce"(%left_idx_575) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc311) + %left_idx_577 = tt.expand_dims %left_idx_576 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc285) + %left_idx_578 = tt.broadcast %left_idx_577 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc286) + %right_idx_579 = arith.muli %y_idx_574, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc287) + %right_idx_580 = "tt.reduce"(%right_idx_579) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc313) + %right_idx_581 = tt.expand_dims %right_idx_580 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc289) + %right_idx_582 = tt.broadcast %right_idx_581 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc290) + %left_idx_583 = tt.reshape %left_idx_578 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_584 = tt.reshape %right_idx_582 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_585 = arith.cmpi slt, %ileft_572, %iright_573 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_586 = 
arith.cmpi eq, %ileft_572, %iright_573 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_587 = arith.cmpi sgt, %left_idx_583, %right_idx_584 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_588 = arith.andi %eq_586, %cond_587 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_589 = arith.ori %cond_585, %cond_588 : tensor<1x16xi1, #blocked5> loc(#loc274) + %cond_590 = arith.extui %cond_589 : tensor<1x16xi1, #blocked5> to tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_591 = arith.xori %cond_590, %flip_155 : tensor<1x16xi32, #blocked5> loc(#loc275) + %cond_592 = arith.cmpi ne, %cond_591, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc276) + %ret_593 = arith.xori %ileft_572, %iright_573 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_594 = arith.select %cond_592, %ret_593, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_595 = arith.xori %ret_559, %ret_594 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_596 = arith.xori %left_idx_583, %right_idx_584 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_597 = arith.select %cond_592, %new_idxs_596, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_598 = arith.xori %new_idxs_562, %new_idxs_597 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_599 = tt.reshape %ret_595 : tensor<1x16xi32, #blocked5> -> tensor<1x2x8xi32, #blocked1> loc(#loc260) + %ileft_600 = arith.muli %y_599, %ileft_266 : tensor<1x2x8xi32, #blocked1> loc(#loc261) + %ileft_601 = "tt.reduce"(%ileft_600) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc307) + %ileft_602 = tt.expand_dims %ileft_601 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> 
tensor<1x1x8xi32, #blocked1> loc(#loc263) + %ileft_603 = tt.broadcast %ileft_602 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc264) + %iright_604 = arith.muli %y_599, %flip_154 : tensor<1x2x8xi32, #blocked1> loc(#loc265) + %iright_605 = "tt.reduce"(%iright_604) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc309) + %iright_606 = tt.expand_dims %iright_605 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc267) + %iright_607 = tt.broadcast %iright_606 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc268) + %ileft_608 = tt.reshape %ileft_603 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_609 = tt.reshape %iright_607 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_610 = tt.reshape %new_idxs_598 : tensor<1x16xi32, #blocked5> -> tensor<1x2x8xi32, #blocked1> loc(#loc282) + %left_idx_611 = arith.muli %y_idx_610, %ileft_266 : tensor<1x2x8xi32, #blocked1> loc(#loc283) + %left_idx_612 = "tt.reduce"(%left_idx_611) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc311) + %left_idx_613 = tt.expand_dims %left_idx_612 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc285) + %left_idx_614 = tt.broadcast 
%left_idx_613 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc286) + %right_idx_615 = arith.muli %y_idx_610, %flip_154 : tensor<1x2x8xi32, #blocked1> loc(#loc287) + %right_idx_616 = "tt.reduce"(%right_idx_615) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<1x2x8xi32, #blocked1>) -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc313) + %right_idx_617 = tt.expand_dims %right_idx_616 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x8xi32, #blocked1> loc(#loc289) + %right_idx_618 = tt.broadcast %right_idx_617 : tensor<1x1x8xi32, #blocked1> -> tensor<1x2x8xi32, #blocked1> loc(#loc290) + %left_idx_619 = tt.reshape %left_idx_614 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_620 = tt.reshape %right_idx_618 : tensor<1x2x8xi32, #blocked1> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_621 = arith.cmpi slt, %ileft_608, %iright_609 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_622 = arith.cmpi eq, %ileft_608, %iright_609 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_623 = arith.cmpi sgt, %left_idx_619, %right_idx_620 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_624 = arith.andi %eq_622, %cond_623 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_625 = arith.ori %cond_621, %cond_624 : tensor<1x16xi1, #blocked5> loc(#loc274) + %ret_626 = arith.xori %ileft_608, %iright_609 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_627 = arith.select %cond_625, %ret_626, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_628 = arith.xori %ret_595, %ret_627 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_629 = arith.xori %left_idx_619, %right_idx_620 : tensor<1x16xi32, 
#blocked5> loc(#loc294) + %new_idxs_630 = arith.select %cond_625, %new_idxs_629, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_631 = arith.xori %new_idxs_598, %new_idxs_630 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_632 = tt.reshape %ret_628 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc260) + %ileft_633 = arith.muli %y_632, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc261) + %ileft_634 = "tt.reduce"(%ileft_633) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc307) + %ileft_635 = tt.expand_dims %ileft_634 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc263) + %ileft_636 = tt.broadcast %ileft_635 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc264) + %iright_637 = arith.muli %y_632, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc265) + %iright_638 = "tt.reduce"(%iright_637) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc309) + %iright_639 = tt.expand_dims %iright_638 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc267) + %iright_640 = tt.broadcast %iright_639 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc268) + %ileft_641 = tt.reshape %ileft_636 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, 
#blocked5> loc(#loc269) + %iright_642 = tt.reshape %iright_640 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_643 = tt.reshape %new_idxs_631 : tensor<1x16xi32, #blocked5> -> tensor<2x2x4xi32, #blocked2> loc(#loc282) + %left_idx_644 = arith.muli %y_idx_643, %ileft_157 : tensor<2x2x4xi32, #blocked2> loc(#loc283) + %left_idx_645 = "tt.reduce"(%left_idx_644) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc311) + %left_idx_646 = tt.expand_dims %left_idx_645 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc285) + %left_idx_647 = tt.broadcast %left_idx_646 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc286) + %right_idx_648 = arith.muli %y_idx_643, %flip_79 : tensor<2x2x4xi32, #blocked2> loc(#loc287) + %right_idx_649 = "tt.reduce"(%right_idx_648) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<2x2x4xi32, #blocked2>) -> tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> loc(#loc313) + %right_idx_650 = tt.expand_dims %right_idx_649 {axis = 1 : i32} : tensor<2x4xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<2x1x4xi32, #blocked2> loc(#loc289) + %right_idx_651 = tt.broadcast %right_idx_650 : tensor<2x1x4xi32, #blocked2> -> tensor<2x2x4xi32, #blocked2> loc(#loc290) + %left_idx_652 = tt.reshape %left_idx_647 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc291) + 
%right_idx_653 = tt.reshape %right_idx_651 : tensor<2x2x4xi32, #blocked2> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_654 = arith.cmpi slt, %ileft_641, %iright_642 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_655 = arith.cmpi eq, %ileft_641, %iright_642 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_656 = arith.cmpi sgt, %left_idx_652, %right_idx_653 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_657 = arith.andi %eq_655, %cond_656 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_658 = arith.ori %cond_654, %cond_657 : tensor<1x16xi1, #blocked5> loc(#loc274) + %ret_659 = arith.xori %ileft_641, %iright_642 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_660 = arith.select %cond_658, %ret_659, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_661 = arith.xori %ret_628, %ret_660 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_662 = arith.xori %left_idx_652, %right_idx_653 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_663 = arith.select %cond_658, %new_idxs_662, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_664 = arith.xori %new_idxs_631, %new_idxs_663 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_665 = tt.reshape %ret_661 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc260) + %ileft_666 = arith.muli %y_665, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc261) + %ileft_667 = "tt.reduce"(%ileft_666) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc307) + %ileft_668 = tt.expand_dims %ileft_667 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc263) + %ileft_669 = 
tt.broadcast %ileft_668 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc264) + %iright_670 = arith.muli %y_665, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc265) + %iright_671 = "tt.reduce"(%iright_670) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc309) + %iright_672 = tt.expand_dims %iright_671 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc267) + %iright_673 = tt.broadcast %iright_672 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc268) + %ileft_674 = tt.reshape %ileft_669 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_675 = tt.reshape %iright_673 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_676 = tt.reshape %new_idxs_664 : tensor<1x16xi32, #blocked5> -> tensor<4x2x2xi32, #blocked3> loc(#loc282) + %left_idx_677 = arith.muli %y_idx_676, %ileft_82 : tensor<4x2x2xi32, #blocked3> loc(#loc283) + %left_idx_678 = "tt.reduce"(%left_idx_677) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc311) + %left_idx_679 = tt.expand_dims %left_idx_678 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc285) + %left_idx_680 = tt.broadcast %left_idx_679 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, 
#blocked3> loc(#loc286) + %right_idx_681 = arith.muli %y_idx_676, %flip_40 : tensor<4x2x2xi32, #blocked3> loc(#loc287) + %right_idx_682 = "tt.reduce"(%right_idx_681) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<4x2x2xi32, #blocked3>) -> tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> loc(#loc313) + %right_idx_683 = tt.expand_dims %right_idx_682 {axis = 1 : i32} : tensor<4x2xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<4x1x2xi32, #blocked3> loc(#loc289) + %right_idx_684 = tt.broadcast %right_idx_683 : tensor<4x1x2xi32, #blocked3> -> tensor<4x2x2xi32, #blocked3> loc(#loc290) + %left_idx_685 = tt.reshape %left_idx_680 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_686 = tt.reshape %right_idx_684 : tensor<4x2x2xi32, #blocked3> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_687 = arith.cmpi slt, %ileft_674, %iright_675 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_688 = arith.cmpi eq, %ileft_674, %iright_675 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_689 = arith.cmpi sgt, %left_idx_685, %right_idx_686 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_690 = arith.andi %eq_688, %cond_689 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_691 = arith.ori %cond_687, %cond_690 : tensor<1x16xi1, #blocked5> loc(#loc274) + %ret_692 = arith.xori %ileft_674, %iright_675 : tensor<1x16xi32, #blocked5> loc(#loc277) + %ret_693 = arith.select %cond_691, %ret_692, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc278) + %ret_694 = arith.xori %ret_661, %ret_693 : tensor<1x16xi32, #blocked5> loc(#loc279) + %new_idxs_695 = arith.xori %left_idx_685, %right_idx_686 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_696 = arith.select %cond_691, 
%new_idxs_695, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_697 = arith.xori %new_idxs_664, %new_idxs_696 : tensor<1x16xi32, #blocked5> loc(#loc281) + %y_698 = tt.reshape %ret_694 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc260) + %ileft_699 = arith.muli %y_698, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc261) + %ileft_700 = "tt.reduce"(%ileft_699) <{axis = 1 : i32}> ({ + ^bb0(%ileft_743: i32 loc(callsite(#loc1 at #loc262)), %ileft_744: i32 loc(callsite(#loc1 at #loc262))): + %ileft_745 = arith.addi %ileft_743, %ileft_744 : i32 loc(#loc319) + tt.reduce.return %ileft_745 : i32 loc(#loc307) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc307) + %ileft_701 = tt.expand_dims %ileft_700 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc263) + %ileft_702 = tt.broadcast %ileft_701 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc264) + %iright_703 = arith.muli %y_698, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc265) + %iright_704 = "tt.reduce"(%iright_703) <{axis = 1 : i32}> ({ + ^bb0(%iright_743: i32 loc(callsite(#loc1 at #loc266)), %iright_744: i32 loc(callsite(#loc1 at #loc266))): + %iright_745 = arith.addi %iright_743, %iright_744 : i32 loc(#loc320) + tt.reduce.return %iright_745 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc309) + %iright_705 = tt.expand_dims %iright_704 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc267) + %iright_706 = tt.broadcast %iright_705 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc268) + %ileft_707 = tt.reshape %ileft_702 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc269) + %iright_708 = tt.reshape %iright_706 : 
tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc270) + %y_idx_709 = tt.reshape %new_idxs_697 : tensor<1x16xi32, #blocked5> -> tensor<8x2x1xi32, #blocked4> loc(#loc282) + %left_idx_710 = arith.muli %y_idx_709, %ileft : tensor<8x2x1xi32, #blocked4> loc(#loc283) + %left_idx_711 = "tt.reduce"(%left_idx_710) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_743: i32 loc(callsite(#loc1 at #loc284)), %left_idx_744: i32 loc(callsite(#loc1 at #loc284))): + %left_idx_745 = arith.addi %left_idx_743, %left_idx_744 : i32 loc(#loc321) + tt.reduce.return %left_idx_745 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc311) + %left_idx_712 = tt.expand_dims %left_idx_711 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc285) + %left_idx_713 = tt.broadcast %left_idx_712 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc286) + %right_idx_714 = arith.muli %y_idx_709, %iright : tensor<8x2x1xi32, #blocked4> loc(#loc287) + %right_idx_715 = "tt.reduce"(%right_idx_714) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_743: i32 loc(callsite(#loc1 at #loc288)), %right_idx_744: i32 loc(callsite(#loc1 at #loc288))): + %right_idx_745 = arith.addi %right_idx_743, %right_idx_744 : i32 loc(#loc322) + tt.reduce.return %right_idx_745 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32, #blocked4>) -> tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> loc(#loc313) + %right_idx_716 = tt.expand_dims %right_idx_715 {axis = 1 : i32} : tensor<8x1xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<8x1x1xi32, #blocked4> loc(#loc289) + %right_idx_717 = tt.broadcast %right_idx_716 : tensor<8x1x1xi32, #blocked4> -> tensor<8x2x1xi32, #blocked4> loc(#loc290) + %left_idx_718 = tt.reshape %left_idx_713 : tensor<8x2x1xi32, #blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc291) + %right_idx_719 = tt.reshape %right_idx_717 : tensor<8x2x1xi32, 
#blocked4> -> tensor<1x16xi32, #blocked5> loc(#loc292) + %cond_720 = arith.cmpi slt, %ileft_707, %iright_708 : tensor<1x16xi32, #blocked5> loc(#loc271) + %eq_721 = arith.cmpi eq, %ileft_707, %iright_708 : tensor<1x16xi32, #blocked5> loc(#loc272) + %cond_722 = arith.cmpi sgt, %left_idx_718, %right_idx_719 : tensor<1x16xi32, #blocked5> loc(#loc293) + %cond_723 = arith.andi %eq_721, %cond_722 : tensor<1x16xi1, #blocked5> loc(#loc273) + %cond_724 = arith.ori %cond_720, %cond_723 : tensor<1x16xi1, #blocked5> loc(#loc274) + %new_idxs_725 = arith.xori %left_idx_718, %right_idx_719 : tensor<1x16xi32, #blocked5> loc(#loc294) + %new_idxs_726 = arith.select %cond_724, %new_idxs_725, %cst_6 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc280) + %new_idxs_727 = arith.xori %new_idxs_697, %new_idxs_726 : tensor<1x16xi32, #blocked5> loc(#loc281) + %tmp20 = arith.extui %tmp5 : tensor<1x16xi1, #blocked5> to tensor<1x16xi64, #blocked5> loc(#loc215) + %tmp20_728 = arith.extui %tmp5_28 : tensor<1x16xi1, #blocked> to tensor<1x16xi64, #blocked> loc(#loc215) + %tmp23 = arith.select %tmp0_22, %tmp20, %cst_10 : tensor<1x16xi1, #blocked5>, tensor<1x16xi64, #blocked5> loc(#loc183) + %tmp23_729 = arith.select %tmp0_23, %tmp20_728, %cst : tensor<1x16xi1, #blocked>, tensor<1x16xi64, #blocked> loc(#loc183) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_743: i64 loc(callsite(#loc1 at #loc184)), %tmp24_744: i64 loc(callsite(#loc1 at #loc184))): + %tmp24_745 = arith.addi %tmp24_743, %tmp24_744 : i64 loc(#loc295) + tt.reduce.return %tmp24_745 : i64 loc(#loc216) + }) : (tensor<1x16xi64, #blocked5>) -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc216) + %tmp24_730 = "tt.reduce"(%tmp23_729) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_743: i64 loc(callsite(#loc1 at #loc184)), %tmp24_744: i64 loc(callsite(#loc1 at #loc184))): + %tmp24_745 = arith.addi %tmp24_743, %tmp24_744 : i64 loc(#loc295) + tt.reduce.return %tmp24_745 : i64 loc(#loc216) + }) : 
(tensor<1x16xi64, #blocked>) -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc216) + %tmp24_731 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<1x1xi64, #blocked5> loc(#loc185) + %tmp24_732 = tt.expand_dims %tmp24_730 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi64, #blocked> loc(#loc185) + %tmp25 = arith.extui %tmp14 : tensor<1x16xi1, #blocked5> to tensor<1x16xi64, #blocked5> loc(#loc218) + %tmp25_733 = arith.extui %tmp14_395 : tensor<1x16xi1, #blocked> to tensor<1x16xi64, #blocked> loc(#loc218) + %tmp28 = arith.select %tmp0_22, %tmp25, %cst_10 : tensor<1x16xi1, #blocked5>, tensor<1x16xi64, #blocked5> loc(#loc187) + %tmp28_734 = arith.select %tmp0_23, %tmp25_733, %cst : tensor<1x16xi1, #blocked>, tensor<1x16xi64, #blocked> loc(#loc187) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_743: i64 loc(callsite(#loc1 at #loc188)), %tmp29_744: i64 loc(callsite(#loc1 at #loc188))): + %tmp29_745 = arith.addi %tmp29_743, %tmp29_744 : i64 loc(#loc296) + tt.reduce.return %tmp29_745 : i64 loc(#loc219) + }) : (tensor<1x16xi64, #blocked5>) -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> loc(#loc219) + %tmp29_735 = "tt.reduce"(%tmp28_734) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_743: i64 loc(callsite(#loc1 at #loc188)), %tmp29_744: i64 loc(callsite(#loc1 at #loc188))): + %tmp29_745 = arith.addi %tmp29_743, %tmp29_744 : i64 loc(#loc296) + tt.reduce.return %tmp29_745 : i64 loc(#loc219) + }) : (tensor<1x16xi64, #blocked>) -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc219) + %tmp29_736 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked5}>> -> tensor<1x1xi64, #blocked5> loc(#loc189) + %tmp29_737 = tt.expand_dims %tmp29_735 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi64, #blocked> loc(#loc189) + %tmp30 = arith.trunci %tmp24_731 
: tensor<1x1xi64, #blocked5> to tensor<1x1xi32, #blocked5> loc(#loc190) + %tmp30_738 = arith.trunci %tmp24_732 : tensor<1x1xi64, #blocked> to tensor<1x1xi32, #blocked> loc(#loc190) + %tmp31 = arith.trunci %tmp29_736 : tensor<1x1xi64, #blocked5> to tensor<1x1xi32, #blocked5> loc(#loc191) + %tmp31_739 = arith.trunci %tmp29_737 : tensor<1x1xi64, #blocked> to tensor<1x1xi32, #blocked> loc(#loc191) + %tmp34 = tt.broadcast %tmp30 : tensor<1x1xi32, #blocked5> -> tensor<1x16xi32, #blocked5> loc(#loc192) + %tmp34_740 = arith.cmpi slt, %r0_index_12, %tmp34 : tensor<1x16xi32, #blocked5> loc(#loc192) + %tmp36 = arith.select %tmp34_740, %new_idxs_394, %cst_8 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc193) + %tmp38 = arith.addi %tmp36, %cst_7 : tensor<1x16xi32, #blocked5> loc(#loc194) + %tmp39 = arith.cmpi slt, %tmp36, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc195) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc196) + %0 = arith.cmpi sge, %tmp40, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc81) + %1 = arith.cmpi slt, %tmp40, %cst_7 : tensor<1x16xi32, #blocked5> loc(#loc82) + %2 = arith.andi %0, %1 : tensor<1x16xi1, #blocked5> loc(#loc83) + %xmask_741 = arith.cmpi sge, %xoffset, %c32_i32 : i32 loc(#loc221) + %3 = tt.splat %xmask_741 : i1 -> tensor<1x16xi1, #blocked5> loc(#loc197) + %4 = arith.ori %2, %3 : tensor<1x16xi1, #blocked5> loc(#loc85) + tt.assert %4, "index out of bounds: 0 <= tmp40 < 17" : tensor<1x16xi1, #blocked5> loc(#loc86) + %tmp45 = tt.broadcast %tmp31 : tensor<1x1xi32, #blocked5> -> tensor<1x16xi32, #blocked5> loc(#loc198) + %tmp45_742 = arith.cmpi slt, %r0_index_12, %tmp45 : tensor<1x16xi32, #blocked5> loc(#loc198) + %tmp46 = arith.select %tmp45_742, %new_idxs_727, %cst_8 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc199) + %tmp47 = arith.addi %tmp46, %cst_7 : tensor<1x16xi32, #blocked5> loc(#loc200) + %tmp48 = arith.cmpi slt, %tmp46, %cst_6 : 
tensor<1x16xi32, #blocked5> loc(#loc201) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<1x16xi1, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc202) + %5 = arith.cmpi sge, %tmp49, %cst_6 : tensor<1x16xi32, #blocked5> loc(#loc92) + %6 = arith.cmpi slt, %tmp49, %cst_7 : tensor<1x16xi32, #blocked5> loc(#loc93) + %7 = arith.andi %5, %6 : tensor<1x16xi1, #blocked5> loc(#loc94) + %8 = arith.ori %7, %3 : tensor<1x16xi1, #blocked5> loc(#loc95) + tt.assert %8, "index out of bounds: 0 <= tmp49 < 17" : tensor<1x16xi1, #blocked5> loc(#loc96) + %9 = tt.addptr %out_ptr4, %xoffset : !tt.ptr, i32 loc(#loc97) + %10 = tt.splat %9 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc98) + %11 = tt.splat %xmask : i1 -> tensor<1x1xi1, #blocked> loc(#loc98) + tt.store %10, %tmp30_738, %11 : tensor<1x1x!tt.ptr, #blocked> loc(#loc98) + %12 = tt.addptr %out_ptr5, %xoffset : !tt.ptr, i32 loc(#loc99) + %13 = tt.splat %12 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc100) + tt.store %13, %tmp31_739, %11 : tensor<1x1x!tt.ptr, #blocked> loc(#loc100) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked5> loc(#loc101) + %15 = tt.addptr %14, %tmp0_16 : tensor<1x16x!tt.ptr, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc101) + tt.store %15, %new_idxs_394, %tmp0_22 : tensor<1x16x!tt.ptr, #blocked5> loc(#loc102) + %16 = arith.muli %xoffset, %c17_i32 : i32 loc(#loc103) + %17 = tt.splat %16 : i32 -> tensor<1x16xi32, #blocked5> loc(#loc203) + %18 = arith.addi %tmp40, %17 : tensor<1x16xi32, #blocked5> loc(#loc104) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked5> loc(#loc105) + %20 = tt.addptr %19, %18 : tensor<1x16x!tt.ptr, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc105) + %21 = ttg.convert_layout %20 : tensor<1x16x!tt.ptr, #blocked5> -> tensor<1x16x!tt.ptr, #blocked> loc(#loc106) + tt.store %21, %cst_5, %tmp0_23 : tensor<1x16x!tt.ptr, #blocked> loc(#loc106) + %22 = tt.splat %out_ptr8 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked5> loc(#loc107) + 
%23 = tt.addptr %22, %tmp0_16 : tensor<1x16x!tt.ptr, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc107) + tt.store %23, %new_idxs_727, %tmp0_22 : tensor<1x16x!tt.ptr, #blocked5> loc(#loc108) + %24 = arith.addi %tmp49, %17 : tensor<1x16xi32, #blocked5> loc(#loc109) + %25 = tt.splat %out_ptr9 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked5> loc(#loc110) + %26 = tt.addptr %25, %24 : tensor<1x16x!tt.ptr, #blocked5>, tensor<1x16xi32, #blocked5> loc(#loc110) + %27 = ttg.convert_layout %26 : tensor<1x16x!tt.ptr, #blocked5> -> tensor<1x16x!tt.ptr, #blocked> loc(#loc111) + tt.store %27, %cst_5, %tmp0_23 : tensor<1x16x!tt.ptr, #blocked> loc(#loc111) + tt.return loc(#loc112) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc11 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc15 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc18 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc32 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc51 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc63 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc69 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc84 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc98 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc112 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc122 = loc("xoffset"(#loc2)) +#loc123 = loc("xmask"(#loc3)) +#loc124 = loc("r0_index"(#loc4)) +#loc125 = loc("tmp0"(#loc5)) +#loc126 = loc("tmp0"(#loc6)) +#loc127 = loc("tmp0"(#loc7)) +#loc128 = loc("tmp0"(#loc8)) +#loc129 = loc("tmp2"(#loc9)) +#loc130 = loc("tmp4"(#loc10)) +#loc131 = loc("tmp5"(#loc11)) +#loc132 = loc("tmp7"(#loc12)) +#loc133 = loc("tmp6"(#loc13)) +#loc134 = loc("tmp9"(#loc14)) +#loc135 = loc("flip"(#loc15)) +#loc137 = loc("flip"(#loc18)) +#loc138 = loc("flip"(#loc19)) +#loc139 = loc("y"(#loc20)) +#loc140 = loc("left_mask"(#loc22)) +#loc141 = loc("ileft"(#loc23)) +#loc143 = loc("ileft"(#loc27)) +#loc144 = loc("ileft"(#loc28)) +#loc145 = loc("iright"(#loc29)) +#loc147 = loc("iright"(#loc31)) +#loc148 = loc("iright"(#loc32)) +#loc149 = loc("ileft"(#loc33)) +#loc150 = loc("iright"(#loc34)) +#loc151 = loc("y_idx"(#loc35)) +#loc152 = loc("left_idx"(#loc36)) +#loc153 = loc("left_idx"(#loc37)) +#loc154 = loc("input"(#loc38)) +#loc156 = loc("left_idx"(#loc40)) +#loc157 = loc("left_idx"(#loc41)) +#loc158 = loc("right_idx"(#loc42)) +#loc159 = loc("right_idx"(#loc43)) +#loc161 = loc("right_idx"(#loc45)) +#loc162 = loc("right_idx"(#loc46)) +#loc163 = loc("left_idx"(#loc47)) +#loc164 = loc("right_idx"(#loc48)) +#loc165 = loc("cond"(#loc49)) +#loc166 = loc("eq"(#loc50)) +#loc167 = loc("cond"(#loc51)) +#loc168 = loc("cond"(#loc52)) +#loc169 = loc("cond"(#loc53)) +#loc170 = loc("cond"(#loc54)) +#loc171 = loc("cond"(#loc55)) +#loc172 = loc("ret"(#loc56)) +#loc173 = loc("ret"(#loc57)) +#loc174 = loc("ret"(#loc58)) +#loc175 = loc("new_idxs"(#loc59)) +#loc176 = loc("new_idxs"(#loc60)) +#loc177 = loc("new_idxs"(#loc61)) +#loc178 = loc("tmp14"(#loc62)) +#loc179 = loc("tmp16"(#loc63)) +#loc180 = loc("tmp15"(#loc64)) +#loc182 = loc("tmp20"(#loc66)) +#loc183 = loc("tmp23"(#loc67)) +#loc185 = loc("tmp24"(#loc69)) +#loc186 = loc("tmp25"(#loc70)) 
+#loc187 = loc("tmp28"(#loc71)) +#loc189 = loc("tmp29"(#loc73)) +#loc190 = loc("tmp30"(#loc74)) +#loc191 = loc("tmp31"(#loc75)) +#loc192 = loc("tmp34"(#loc76)) +#loc193 = loc("tmp36"(#loc77)) +#loc194 = loc("tmp38"(#loc78)) +#loc195 = loc("tmp39"(#loc79)) +#loc196 = loc("tmp40"(#loc80)) +#loc197 = loc(fused[#loc85, #loc84]) +#loc198 = loc("tmp45"(#loc87)) +#loc199 = loc("tmp46"(#loc88)) +#loc200 = loc("tmp47"(#loc89)) +#loc201 = loc("tmp48"(#loc90)) +#loc202 = loc("tmp49"(#loc91)) +#loc203 = loc(fused[#loc104, #loc103]) +#loc204 = loc(fused[#loc126, #loc125]) +#loc205 = loc(fused[#loc128, #loc123]) +#loc206 = loc(fused[#loc132, #loc133]) +#loc207 = loc(callsite(#loc135 at #loc136)) +#loc208 = loc(callsite(#loc137 at #loc136)) +#loc209 = loc(callsite(#loc138 at #loc136)) +#loc211 = loc("cond"(#loc165)) +#loc212 = loc("eq"(#loc166)) +#loc213 = loc(fused[#loc179, #loc180]) +#loc215 = loc(fused[#loc182, #loc132, #loc133]) +#loc216 = loc(callsite(#loc24 at #loc184)) +#loc218 = loc(fused[#loc186, #loc179, #loc180]) +#loc219 = loc(callsite(#loc24 at #loc188)) +#loc221 = loc(fused[#loc84, #loc123]) +#loc222 = loc(callsite(#loc139 at #loc210)) +#loc223 = loc(callsite(#loc140 at #loc210)) +#loc224 = loc(callsite(#loc141 at #loc210)) +#loc226 = loc(callsite(#loc143 at #loc210)) +#loc227 = loc(callsite(#loc144 at #loc210)) +#loc228 = loc(callsite(#loc145 at #loc210)) +#loc230 = loc(callsite(#loc147 at #loc210)) +#loc231 = loc(callsite(#loc148 at #loc210)) +#loc232 = loc(callsite(#loc149 at #loc210)) +#loc233 = loc(callsite(#loc150 at #loc210)) +#loc234 = loc(callsite(#loc151 at #loc210)) +#loc235 = loc(callsite(#loc152 at #loc210)) +#loc236 = loc(callsite(#loc153 at #loc210)) +#loc238 = loc(callsite(#loc156 at #loc210)) +#loc239 = loc(callsite(#loc157 at #loc210)) +#loc240 = loc(callsite(#loc158 at #loc210)) +#loc241 = loc(callsite(#loc159 at #loc210)) +#loc243 = loc(callsite(#loc161 at #loc210)) +#loc244 = loc(callsite(#loc162 at #loc210)) +#loc245 = loc(callsite(#loc163 at 
#loc210)) +#loc246 = loc(callsite(#loc164 at #loc210)) +#loc247 = loc(callsite(#loc211 at #loc210)) +#loc248 = loc(callsite(#loc212 at #loc210)) +#loc249 = loc(callsite(#loc167 at #loc210)) +#loc250 = loc(callsite(#loc168 at #loc210)) +#loc251 = loc(callsite(#loc169 at #loc210)) +#loc252 = loc(callsite(#loc170 at #loc210)) +#loc253 = loc(callsite(#loc171 at #loc210)) +#loc254 = loc(callsite(#loc172 at #loc210)) +#loc255 = loc(callsite(#loc173 at #loc210)) +#loc256 = loc(callsite(#loc174 at #loc210)) +#loc257 = loc(callsite(#loc175 at #loc210)) +#loc258 = loc(callsite(#loc176 at #loc210)) +#loc259 = loc(callsite(#loc177 at #loc210)) +#loc260 = loc(callsite(#loc139 at #loc214)) +#loc261 = loc(callsite(#loc141 at #loc214)) +#loc263 = loc(callsite(#loc143 at #loc214)) +#loc264 = loc(callsite(#loc144 at #loc214)) +#loc265 = loc(callsite(#loc145 at #loc214)) +#loc267 = loc(callsite(#loc147 at #loc214)) +#loc268 = loc(callsite(#loc148 at #loc214)) +#loc269 = loc(callsite(#loc149 at #loc214)) +#loc270 = loc(callsite(#loc150 at #loc214)) +#loc271 = loc(callsite(#loc211 at #loc214)) +#loc272 = loc(callsite(#loc212 at #loc214)) +#loc273 = loc(callsite(#loc168 at #loc214)) +#loc274 = loc(callsite(#loc169 at #loc214)) +#loc275 = loc(callsite(#loc170 at #loc214)) +#loc276 = loc(callsite(#loc171 at #loc214)) +#loc277 = loc(callsite(#loc172 at #loc214)) +#loc278 = loc(callsite(#loc173 at #loc214)) +#loc279 = loc(callsite(#loc174 at #loc214)) +#loc280 = loc(callsite(#loc176 at #loc214)) +#loc281 = loc(callsite(#loc177 at #loc214)) +#loc282 = loc(callsite(#loc151 at #loc214)) +#loc283 = loc(callsite(#loc153 at #loc214)) +#loc285 = loc(callsite(#loc156 at #loc214)) +#loc286 = loc(callsite(#loc157 at #loc214)) +#loc287 = loc(callsite(#loc159 at #loc214)) +#loc289 = loc(callsite(#loc161 at #loc214)) +#loc290 = loc(callsite(#loc162 at #loc214)) +#loc291 = loc(callsite(#loc163 at #loc214)) +#loc292 = loc(callsite(#loc164 at #loc214)) +#loc293 = loc(callsite(#loc167 at #loc214)) +#loc294 
= loc(callsite(#loc175 at #loc214)) +#loc295 = loc(callsite(#loc26 at #loc216)) +#loc296 = loc(callsite(#loc26 at #loc219)) +#loc297 = loc(callsite(#loc24 at #loc225)) +#loc299 = loc(callsite(#loc24 at #loc229)) +#loc301 = loc(callsite(#loc154 at #loc237)) +#loc302 = loc(callsite(#loc24 at #loc237)) +#loc304 = loc(callsite(#loc154 at #loc242)) +#loc305 = loc(callsite(#loc24 at #loc242)) +#loc307 = loc(callsite(#loc24 at #loc262)) +#loc309 = loc(callsite(#loc24 at #loc266)) +#loc311 = loc(callsite(#loc24 at #loc284)) +#loc313 = loc(callsite(#loc24 at #loc288)) +#loc315 = loc(callsite(#loc26 at #loc297)) +#loc316 = loc(callsite(#loc26 at #loc299)) +#loc317 = loc(callsite(#loc26 at #loc302)) +#loc318 = loc(callsite(#loc26 at #loc305)) +#loc319 = loc(callsite(#loc26 at #loc307)) +#loc320 = loc(callsite(#loc26 at #loc309)) +#loc321 = loc(callsite(#loc26 at #loc311)) +#loc322 = loc(callsite(#loc26 at #loc313)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir new file mode 100644 index 0000000000000000000000000000000000000000..7d325f5225b251ec47e9ce6f63ed73811b93cc69 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ/triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.ttir @@ -0,0 +1,1437 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":18:0) +#loc1 = loc(unknown) +#loc17 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":662:12) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":46:71) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":634:73) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:51) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:53) +#loc41 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:50) +#loc46 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:51) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":51:71) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:26) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:26) +#loc115 = loc("in_ptr0"(#loc)) +#loc116 = loc("out_ptr4"(#loc)) +#loc117 = loc("out_ptr5"(#loc)) +#loc118 = loc("out_ptr6"(#loc)) +#loc119 = loc("out_ptr7"(#loc)) +#loc120 = loc("out_ptr8"(#loc)) +#loc121 = loc("out_ptr9"(#loc)) +#loc122 = loc("xnumel"(#loc)) +#loc123 = loc("r0_numel"(#loc)) +#loc139 = loc(callsite(#loc17 at #loc18)) +#loc146 = loc("ileft"(#loc27)) +#loc150 = loc("iright"(#loc32)) +#loc159 = loc("left_idx"(#loc41)) +#loc164 = loc("right_idx"(#loc46)) +#loc185 = loc(callsite(#loc17 at #loc67)) +#loc188 = loc("tmp24"(#loc70)) +#loc192 = loc("tmp29"(#loc74)) +#loc215 = loc(callsite(#loc23 at #loc139)) +#loc219 = loc(callsite(#loc23 at #loc185)) +#loc222 = loc(callsite(#loc1 at #loc188)) +#loc225 = loc(callsite(#loc1 at #loc192)) +#loc229 = 
loc(callsite(#loc146 at #loc215)) +#loc233 = loc(callsite(#loc150 at #loc215)) +#loc241 = loc(callsite(#loc159 at #loc215)) +#loc246 = loc(callsite(#loc164 at #loc215)) +#loc266 = loc(callsite(#loc146 at #loc219)) +#loc270 = loc(callsite(#loc150 at #loc219)) +#loc288 = loc(callsite(#loc159 at #loc219)) +#loc292 = loc(callsite(#loc164 at #loc219)) +#loc302 = loc(callsite(#loc1 at #loc229)) +#loc304 = loc(callsite(#loc1 at #loc233)) +#loc307 = loc(callsite(#loc1 at #loc241)) +#loc310 = loc(callsite(#loc1 at #loc246)) +#loc312 = loc(callsite(#loc1 at #loc266)) +#loc314 = loc(callsite(#loc1 at #loc270)) +#loc316 = loc(callsite(#loc1 at #loc288)) +#loc318 = loc(callsite(#loc1 at #loc292)) +module { + tt.func public @triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr4: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr4"(#loc)), %out_ptr5: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr5"(#loc)), %out_ptr6: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr6"(#loc)), %out_ptr7: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr7"(#loc)), %out_ptr8: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr8"(#loc)), %out_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr9"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %c17_i32 = arith.constant 17 : i32 loc(#loc1) + %true = arith.constant true loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %xmask = arith.constant 32 : i32 loc(#loc124) + %cst = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<1x16xi32> loc(#loc1) + %cst_1 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc1) + %cst_2 = arith.constant dense<17> : tensor<1x16xi32> loc(#loc1) + %cst_3 = arith.constant dense<16> : tensor<1x16xi32> 
loc(#loc1) + %cst_4 = arith.constant dense<16384> : tensor<1x16xi64> loc(#loc1) + %cst_5 = arith.constant dense<0> : tensor<1x16xi64> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc125) + %xmask_6 = arith.cmpi slt, %xoffset, %xmask : i32 loc(#loc124) + %xmask_7 = tt.splat %xmask_6 : i1 -> tensor<1x1xi1> loc(#loc124) + %r0_index = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc126) + %r0_index_8 = tt.expand_dims %r0_index {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc127) + %tmp0 = arith.muli %xoffset, %c16_i32 : i32 loc(#loc128) + %tmp0_9 = tt.splat %tmp0 : i32 -> tensor<1x16xi32> loc(#loc208) + %tmp0_10 = arith.addi %r0_index_8, %tmp0_9 : tensor<1x16xi32> loc(#loc129) + %tmp0_11 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc130) + %tmp0_12 = tt.addptr %tmp0_11, %tmp0_10 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc130) + %tmp0_13 = tt.splat %xmask_6 : i1 -> tensor<1x16xi1> loc(#loc209) + %tmp0_14 = tt.load %tmp0_12, %tmp0_13, %cst_5 : tensor<1x16x!tt.ptr> loc(#loc131) + %tmp2 = arith.cmpi sgt, %tmp0_14, %cst_5 : tensor<1x16xi64> loc(#loc132) + %tmp4 = arith.cmpi slt, %tmp0_14, %cst_4 : tensor<1x16xi64> loc(#loc133) + %tmp5 = arith.andi %tmp2, %tmp4 : tensor<1x16xi1> loc(#loc134) + %tmp7 = arith.extui %tmp5 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc210) + %tmp9 = arith.trunci %r0_index_8 : tensor<1x16xi32> to tensor<1x16xi16> loc(#loc137) + %flip = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc211) + %flip_15 = tt.expand_dims %flip {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc212) + %flip_16 = tt.expand_dims %flip_15 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc212) + %flip_17 = tt.broadcast %flip_16 : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc213) + %flip_18 = tt.reshape %flip_17 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc214) + %y = tt.reshape %tmp7 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc226) + %left_mask = 
arith.subi %cst, %flip_16 : tensor<1x2x1xi32> loc(#loc227) + %ileft = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc228) + %ileft_19 = arith.muli %y, %ileft : tensor<8x2x1xi32> loc(#loc228) + %ileft_20 = "tt.reduce"(%ileft_19) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc301) + %ileft_21 = tt.expand_dims %ileft_20 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc230) + %ileft_22 = tt.broadcast %ileft_21 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc231) + %iright = tt.broadcast %flip_16 : tensor<1x2x1xi32> -> tensor<8x2x1xi32> loc(#loc232) + %iright_23 = arith.muli %y, %iright : tensor<8x2x1xi32> loc(#loc232) + %iright_24 = "tt.reduce"(%iright_23) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc303) + %iright_25 = tt.expand_dims %iright_24 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc234) + %iright_26 = tt.broadcast %iright_25 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc235) + %ileft_27 = tt.reshape %ileft_22 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_28 = tt.reshape %iright_26 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx = tt.reshape %tmp9 : tensor<1x16xi16> -> tensor<8x2x1xi16> loc(#loc238) + %left_idx = arith.trunci %left_mask : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc239) + %left_idx_29 = tt.broadcast %left_idx : tensor<1x2x1xi16> -> tensor<8x2x1xi16> loc(#loc240) + %left_idx_30 = arith.muli %y_idx, %left_idx_29 : tensor<8x2x1xi16> loc(#loc240) + 
%input = arith.extsi %left_idx_30 : tensor<8x2x1xi16> to tensor<8x2x1xi32> loc(#loc305) + %left_idx_31 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc306) + %left_idx_32 = tt.expand_dims %left_idx_31 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc242) + %left_idx_33 = tt.broadcast %left_idx_32 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc243) + %right_idx = arith.trunci %flip_16 : tensor<1x2x1xi32> to tensor<1x2x1xi16> loc(#loc244) + %right_idx_34 = tt.broadcast %right_idx : tensor<1x2x1xi16> -> tensor<8x2x1xi16> loc(#loc245) + %right_idx_35 = arith.muli %y_idx, %right_idx_34 : tensor<8x2x1xi16> loc(#loc245) + %input_36 = arith.extsi %right_idx_35 : tensor<8x2x1xi16> to tensor<8x2x1xi32> loc(#loc308) + %right_idx_37 = "tt.reduce"(%input_36) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc309) + %right_idx_38 = tt.expand_dims %right_idx_37 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc247) + %right_idx_39 = tt.broadcast %right_idx_38 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc248) + %left_idx_40 = tt.reshape %left_idx_33 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_41 = tt.reshape %right_idx_39 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc250) + %cond = arith.cmpi slt, %ileft_27, %iright_28 : tensor<1x16xi32> loc(#loc251) + %eq = arith.cmpi eq, %ileft_27, %iright_28 : tensor<1x16xi32> loc(#loc252) + %cond_42 = arith.cmpi sgt, %left_idx_40, 
%right_idx_41 : tensor<1x16xi32> loc(#loc253) + %cond_43 = arith.andi %eq, %cond_42 : tensor<1x16xi1> loc(#loc254) + %cond_44 = arith.ori %cond, %cond_43 : tensor<1x16xi1> loc(#loc255) + %cond_45 = arith.extui %cond_44 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_46 = arith.xori %cond_45, %flip_18 : tensor<1x16xi32> loc(#loc256) + %cond_47 = arith.cmpi ne, %cond_46, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret = arith.xori %ileft_27, %iright_28 : tensor<1x16xi32> loc(#loc258) + %ret_48 = arith.select %cond_47, %ret, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_49 = arith.xori %tmp7, %ret_48 : tensor<1x16xi32> loc(#loc260) + %new_idxs = arith.xori %left_idx_40, %right_idx_41 : tensor<1x16xi32> loc(#loc261) + %new_idxs_50 = arith.select %cond_47, %new_idxs, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_51 = arith.extsi %tmp9 : tensor<1x16xi16> to tensor<1x16xi32> loc(#loc263) + %new_idxs_52 = arith.xori %new_idxs_51, %new_idxs_50 : tensor<1x16xi32> loc(#loc263) + %flip_53 = tt.broadcast %flip_16 : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc213) + %flip_54 = tt.reshape %flip_53 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc214) + %y_55 = tt.reshape %ret_49 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc226) + %ileft_56 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<4x2x2xi32> loc(#loc228) + %ileft_57 = arith.muli %y_55, %ileft_56 : tensor<4x2x2xi32> loc(#loc228) + %ileft_58 = "tt.reduce"(%ileft_57) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc301) + %ileft_59 = tt.expand_dims %ileft_58 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc230) + %ileft_60 = tt.broadcast %ileft_59 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc231) + 
%iright_61 = arith.muli %y_55, %flip_17 : tensor<4x2x2xi32> loc(#loc232) + %iright_62 = "tt.reduce"(%iright_61) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc303) + %iright_63 = tt.expand_dims %iright_62 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc234) + %iright_64 = tt.broadcast %iright_63 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc235) + %ileft_65 = tt.reshape %ileft_60 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_66 = tt.reshape %iright_64 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_67 = tt.reshape %new_idxs_52 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc238) + %left_idx_68 = arith.muli %y_idx_67, %ileft_56 : tensor<4x2x2xi32> loc(#loc240) + %left_idx_69 = "tt.reduce"(%left_idx_68) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc306) + %left_idx_70 = tt.expand_dims %left_idx_69 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc242) + %left_idx_71 = tt.broadcast %left_idx_70 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc243) + %right_idx_72 = arith.muli %y_idx_67, %flip_17 : tensor<4x2x2xi32> loc(#loc245) + %right_idx_73 = "tt.reduce"(%right_idx_72) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> 
loc(#loc309) + %right_idx_74 = tt.expand_dims %right_idx_73 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc247) + %right_idx_75 = tt.broadcast %right_idx_74 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc248) + %left_idx_76 = tt.reshape %left_idx_71 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_77 = tt.reshape %right_idx_75 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_78 = arith.cmpi slt, %ileft_65, %iright_66 : tensor<1x16xi32> loc(#loc251) + %eq_79 = arith.cmpi eq, %ileft_65, %iright_66 : tensor<1x16xi32> loc(#loc252) + %cond_80 = arith.cmpi sgt, %left_idx_76, %right_idx_77 : tensor<1x16xi32> loc(#loc253) + %cond_81 = arith.andi %eq_79, %cond_80 : tensor<1x16xi1> loc(#loc254) + %cond_82 = arith.ori %cond_78, %cond_81 : tensor<1x16xi1> loc(#loc255) + %cond_83 = arith.extui %cond_82 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_84 = arith.xori %cond_83, %flip_54 : tensor<1x16xi32> loc(#loc256) + %cond_85 = arith.cmpi ne, %cond_84, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret_86 = arith.xori %ileft_65, %iright_66 : tensor<1x16xi32> loc(#loc258) + %ret_87 = arith.select %cond_85, %ret_86, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_88 = arith.xori %ret_49, %ret_87 : tensor<1x16xi32> loc(#loc260) + %new_idxs_89 = arith.xori %left_idx_76, %right_idx_77 : tensor<1x16xi32> loc(#loc261) + %new_idxs_90 = arith.select %cond_85, %new_idxs_89, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_91 = arith.xori %new_idxs_52, %new_idxs_90 : tensor<1x16xi32> loc(#loc263) + %y_92 = tt.reshape %ret_88 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc226) + %ileft_93 = arith.muli %y_92, %ileft : tensor<8x2x1xi32> loc(#loc228) + %ileft_94 = "tt.reduce"(%ileft_93) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return 
%ileft_707 : i32 loc(#loc301) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc301) + %ileft_95 = tt.expand_dims %ileft_94 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc230) + %ileft_96 = tt.broadcast %ileft_95 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc231) + %iright_97 = arith.muli %y_92, %iright : tensor<8x2x1xi32> loc(#loc232) + %iright_98 = "tt.reduce"(%iright_97) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc303) + %iright_99 = tt.expand_dims %iright_98 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc234) + %iright_100 = tt.broadcast %iright_99 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc235) + %ileft_101 = tt.reshape %ileft_96 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_102 = tt.reshape %iright_100 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_103 = tt.reshape %new_idxs_91 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc238) + %left_idx_104 = arith.muli %y_idx_103, %ileft : tensor<8x2x1xi32> loc(#loc240) + %left_idx_105 = "tt.reduce"(%left_idx_104) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc306) + %left_idx_106 = tt.expand_dims %left_idx_105 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc242) + %left_idx_107 = tt.broadcast %left_idx_106 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc243) + %right_idx_108 = arith.muli %y_idx_103, %iright : tensor<8x2x1xi32> loc(#loc245) + %right_idx_109 = "tt.reduce"(%right_idx_108) <{axis = 1 : i32}> 
({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc309) + %right_idx_110 = tt.expand_dims %right_idx_109 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc247) + %right_idx_111 = tt.broadcast %right_idx_110 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc248) + %left_idx_112 = tt.reshape %left_idx_107 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_113 = tt.reshape %right_idx_111 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_114 = arith.cmpi slt, %ileft_101, %iright_102 : tensor<1x16xi32> loc(#loc251) + %eq_115 = arith.cmpi eq, %ileft_101, %iright_102 : tensor<1x16xi32> loc(#loc252) + %cond_116 = arith.cmpi sgt, %left_idx_112, %right_idx_113 : tensor<1x16xi32> loc(#loc253) + %cond_117 = arith.andi %eq_115, %cond_116 : tensor<1x16xi1> loc(#loc254) + %cond_118 = arith.ori %cond_114, %cond_117 : tensor<1x16xi1> loc(#loc255) + %cond_119 = arith.extui %cond_118 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_120 = arith.xori %cond_119, %flip_54 : tensor<1x16xi32> loc(#loc256) + %cond_121 = arith.cmpi ne, %cond_120, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret_122 = arith.xori %ileft_101, %iright_102 : tensor<1x16xi32> loc(#loc258) + %ret_123 = arith.select %cond_121, %ret_122, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_124 = arith.xori %ret_88, %ret_123 : tensor<1x16xi32> loc(#loc260) + %new_idxs_125 = arith.xori %left_idx_112, %right_idx_113 : tensor<1x16xi32> loc(#loc261) + %new_idxs_126 = arith.select %cond_121, %new_idxs_125, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_127 = arith.xori %new_idxs_91, %new_idxs_126 : tensor<1x16xi32> loc(#loc263) + %flip_128 = tt.broadcast %flip_16 : tensor<1x2x1xi32> -> 
tensor<1x2x8xi32> loc(#loc213) + %flip_129 = tt.reshape %flip_128 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc214) + %y_130 = tt.reshape %ret_124 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc226) + %ileft_131 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<2x2x4xi32> loc(#loc228) + %ileft_132 = arith.muli %y_130, %ileft_131 : tensor<2x2x4xi32> loc(#loc228) + %ileft_133 = "tt.reduce"(%ileft_132) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc301) + %ileft_134 = tt.expand_dims %ileft_133 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc230) + %ileft_135 = tt.broadcast %ileft_134 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc231) + %iright_136 = arith.muli %y_130, %flip_53 : tensor<2x2x4xi32> loc(#loc232) + %iright_137 = "tt.reduce"(%iright_136) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc303) + %iright_138 = tt.expand_dims %iright_137 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc234) + %iright_139 = tt.broadcast %iright_138 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc235) + %ileft_140 = tt.reshape %ileft_135 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_141 = tt.reshape %iright_139 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_142 = tt.reshape %new_idxs_127 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc238) + %left_idx_143 = arith.muli %y_idx_142, %ileft_131 : tensor<2x2x4xi32> loc(#loc240) + %left_idx_144 = "tt.reduce"(%left_idx_143) <{axis = 1 : i32}> ({ + 
^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc306) + %left_idx_145 = tt.expand_dims %left_idx_144 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc242) + %left_idx_146 = tt.broadcast %left_idx_145 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc243) + %right_idx_147 = arith.muli %y_idx_142, %flip_53 : tensor<2x2x4xi32> loc(#loc245) + %right_idx_148 = "tt.reduce"(%right_idx_147) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc309) + %right_idx_149 = tt.expand_dims %right_idx_148 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc247) + %right_idx_150 = tt.broadcast %right_idx_149 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc248) + %left_idx_151 = tt.reshape %left_idx_146 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_152 = tt.reshape %right_idx_150 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_153 = arith.cmpi slt, %ileft_140, %iright_141 : tensor<1x16xi32> loc(#loc251) + %eq_154 = arith.cmpi eq, %ileft_140, %iright_141 : tensor<1x16xi32> loc(#loc252) + %cond_155 = arith.cmpi sgt, %left_idx_151, %right_idx_152 : tensor<1x16xi32> loc(#loc253) + %cond_156 = arith.andi %eq_154, %cond_155 : tensor<1x16xi1> loc(#loc254) + %cond_157 = arith.ori %cond_153, %cond_156 : tensor<1x16xi1> loc(#loc255) + %cond_158 = arith.extui %cond_157 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_159 = arith.xori %cond_158, %flip_129 : tensor<1x16xi32> loc(#loc256) + %cond_160 = arith.cmpi ne, 
%cond_159, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret_161 = arith.xori %ileft_140, %iright_141 : tensor<1x16xi32> loc(#loc258) + %ret_162 = arith.select %cond_160, %ret_161, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_163 = arith.xori %ret_124, %ret_162 : tensor<1x16xi32> loc(#loc260) + %new_idxs_164 = arith.xori %left_idx_151, %right_idx_152 : tensor<1x16xi32> loc(#loc261) + %new_idxs_165 = arith.select %cond_160, %new_idxs_164, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_166 = arith.xori %new_idxs_127, %new_idxs_165 : tensor<1x16xi32> loc(#loc263) + %y_167 = tt.reshape %ret_163 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc226) + %ileft_168 = arith.muli %y_167, %ileft_56 : tensor<4x2x2xi32> loc(#loc228) + %ileft_169 = "tt.reduce"(%ileft_168) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc301) + %ileft_170 = tt.expand_dims %ileft_169 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc230) + %ileft_171 = tt.broadcast %ileft_170 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc231) + %iright_172 = arith.muli %y_167, %flip_17 : tensor<4x2x2xi32> loc(#loc232) + %iright_173 = "tt.reduce"(%iright_172) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc303) + %iright_174 = tt.expand_dims %iright_173 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc234) + %iright_175 = tt.broadcast %iright_174 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc235) + %ileft_176 = tt.reshape %ileft_171 : 
tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_177 = tt.reshape %iright_175 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_178 = tt.reshape %new_idxs_166 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc238) + %left_idx_179 = arith.muli %y_idx_178, %ileft_56 : tensor<4x2x2xi32> loc(#loc240) + %left_idx_180 = "tt.reduce"(%left_idx_179) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc306) + %left_idx_181 = tt.expand_dims %left_idx_180 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc242) + %left_idx_182 = tt.broadcast %left_idx_181 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc243) + %right_idx_183 = arith.muli %y_idx_178, %flip_17 : tensor<4x2x2xi32> loc(#loc245) + %right_idx_184 = "tt.reduce"(%right_idx_183) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc309) + %right_idx_185 = tt.expand_dims %right_idx_184 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc247) + %right_idx_186 = tt.broadcast %right_idx_185 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc248) + %left_idx_187 = tt.reshape %left_idx_182 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_188 = tt.reshape %right_idx_186 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_189 = arith.cmpi slt, %ileft_176, %iright_177 : tensor<1x16xi32> loc(#loc251) + %eq_190 = arith.cmpi eq, %ileft_176, %iright_177 : tensor<1x16xi32> loc(#loc252) + %cond_191 = arith.cmpi sgt, %left_idx_187, 
%right_idx_188 : tensor<1x16xi32> loc(#loc253) + %cond_192 = arith.andi %eq_190, %cond_191 : tensor<1x16xi1> loc(#loc254) + %cond_193 = arith.ori %cond_189, %cond_192 : tensor<1x16xi1> loc(#loc255) + %cond_194 = arith.extui %cond_193 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_195 = arith.xori %cond_194, %flip_129 : tensor<1x16xi32> loc(#loc256) + %cond_196 = arith.cmpi ne, %cond_195, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret_197 = arith.xori %ileft_176, %iright_177 : tensor<1x16xi32> loc(#loc258) + %ret_198 = arith.select %cond_196, %ret_197, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_199 = arith.xori %ret_163, %ret_198 : tensor<1x16xi32> loc(#loc260) + %new_idxs_200 = arith.xori %left_idx_187, %right_idx_188 : tensor<1x16xi32> loc(#loc261) + %new_idxs_201 = arith.select %cond_196, %new_idxs_200, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_202 = arith.xori %new_idxs_166, %new_idxs_201 : tensor<1x16xi32> loc(#loc263) + %y_203 = tt.reshape %ret_199 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc226) + %ileft_204 = arith.muli %y_203, %ileft : tensor<8x2x1xi32> loc(#loc228) + %ileft_205 = "tt.reduce"(%ileft_204) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc301) + %ileft_206 = tt.expand_dims %ileft_205 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc230) + %ileft_207 = tt.broadcast %ileft_206 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc231) + %iright_208 = arith.muli %y_203, %iright : tensor<8x2x1xi32> loc(#loc232) + %iright_209 = "tt.reduce"(%iright_208) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : 
i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc303) + %iright_210 = tt.expand_dims %iright_209 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc234) + %iright_211 = tt.broadcast %iright_210 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc235) + %ileft_212 = tt.reshape %ileft_207 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_213 = tt.reshape %iright_211 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_214 = tt.reshape %new_idxs_202 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc238) + %left_idx_215 = arith.muli %y_idx_214, %ileft : tensor<8x2x1xi32> loc(#loc240) + %left_idx_216 = "tt.reduce"(%left_idx_215) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc306) + %left_idx_217 = tt.expand_dims %left_idx_216 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc242) + %left_idx_218 = tt.broadcast %left_idx_217 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc243) + %right_idx_219 = arith.muli %y_idx_214, %iright : tensor<8x2x1xi32> loc(#loc245) + %right_idx_220 = "tt.reduce"(%right_idx_219) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc309) + %right_idx_221 = tt.expand_dims %right_idx_220 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc247) + %right_idx_222 = tt.broadcast %right_idx_221 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc248) + %left_idx_223 = tt.reshape %left_idx_218 : 
tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_224 = tt.reshape %right_idx_222 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_225 = arith.cmpi slt, %ileft_212, %iright_213 : tensor<1x16xi32> loc(#loc251) + %eq_226 = arith.cmpi eq, %ileft_212, %iright_213 : tensor<1x16xi32> loc(#loc252) + %cond_227 = arith.cmpi sgt, %left_idx_223, %right_idx_224 : tensor<1x16xi32> loc(#loc253) + %cond_228 = arith.andi %eq_226, %cond_227 : tensor<1x16xi1> loc(#loc254) + %cond_229 = arith.ori %cond_225, %cond_228 : tensor<1x16xi1> loc(#loc255) + %cond_230 = arith.extui %cond_229 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc256) + %cond_231 = arith.xori %cond_230, %flip_129 : tensor<1x16xi32> loc(#loc256) + %cond_232 = arith.cmpi ne, %cond_231, %cst_1 : tensor<1x16xi32> loc(#loc257) + %ret_233 = arith.xori %ileft_212, %iright_213 : tensor<1x16xi32> loc(#loc258) + %ret_234 = arith.select %cond_232, %ret_233, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_235 = arith.xori %ret_199, %ret_234 : tensor<1x16xi32> loc(#loc260) + %new_idxs_236 = arith.xori %left_idx_223, %right_idx_224 : tensor<1x16xi32> loc(#loc261) + %new_idxs_237 = arith.select %cond_232, %new_idxs_236, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_238 = arith.xori %new_idxs_202, %new_idxs_237 : tensor<1x16xi32> loc(#loc263) + %y_239 = tt.reshape %ret_235 : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc226) + %ileft_240 = tt.broadcast %left_mask : tensor<1x2x1xi32> -> tensor<1x2x8xi32> loc(#loc228) + %ileft_241 = arith.muli %y_239, %ileft_240 : tensor<1x2x8xi32> loc(#loc228) + %ileft_242 = "tt.reduce"(%ileft_241) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc301) + %ileft_243 = tt.expand_dims %ileft_242 
{axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc230) + %ileft_244 = tt.broadcast %ileft_243 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc231) + %iright_245 = arith.muli %y_239, %flip_128 : tensor<1x2x8xi32> loc(#loc232) + %iright_246 = "tt.reduce"(%iright_245) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc303) + %iright_247 = tt.expand_dims %iright_246 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc234) + %iright_248 = tt.broadcast %iright_247 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc235) + %ileft_249 = tt.reshape %ileft_244 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_250 = tt.reshape %iright_248 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_251 = tt.reshape %new_idxs_238 : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc238) + %left_idx_252 = arith.muli %y_idx_251, %ileft_240 : tensor<1x2x8xi32> loc(#loc240) + %left_idx_253 = "tt.reduce"(%left_idx_252) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc306) + %left_idx_254 = tt.expand_dims %left_idx_253 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc242) + %left_idx_255 = tt.broadcast %left_idx_254 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc243) + %right_idx_256 = arith.muli %y_idx_251, %flip_128 : tensor<1x2x8xi32> loc(#loc245) + %right_idx_257 = "tt.reduce"(%right_idx_256) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at 
#loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc309) + %right_idx_258 = tt.expand_dims %right_idx_257 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc247) + %right_idx_259 = tt.broadcast %right_idx_258 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc248) + %left_idx_260 = tt.reshape %left_idx_255 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_261 = tt.reshape %right_idx_259 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_262 = arith.cmpi slt, %ileft_249, %iright_250 : tensor<1x16xi32> loc(#loc251) + %eq_263 = arith.cmpi eq, %ileft_249, %iright_250 : tensor<1x16xi32> loc(#loc252) + %cond_264 = arith.cmpi sgt, %left_idx_260, %right_idx_261 : tensor<1x16xi32> loc(#loc253) + %cond_265 = arith.andi %eq_263, %cond_264 : tensor<1x16xi1> loc(#loc254) + %cond_266 = arith.ori %cond_262, %cond_265 : tensor<1x16xi1> loc(#loc255) + %ret_267 = arith.xori %ileft_249, %iright_250 : tensor<1x16xi32> loc(#loc258) + %ret_268 = arith.select %cond_266, %ret_267, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_269 = arith.xori %ret_235, %ret_268 : tensor<1x16xi32> loc(#loc260) + %new_idxs_270 = arith.xori %left_idx_260, %right_idx_261 : tensor<1x16xi32> loc(#loc261) + %new_idxs_271 = arith.select %cond_266, %new_idxs_270, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_272 = arith.xori %new_idxs_238, %new_idxs_271 : tensor<1x16xi32> loc(#loc263) + %y_273 = tt.reshape %ret_269 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc226) + %ileft_274 = arith.muli %y_273, %ileft_131 : tensor<2x2x4xi32> loc(#loc228) + %ileft_275 = "tt.reduce"(%ileft_274) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return 
%ileft_707 : i32 loc(#loc301) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc301) + %ileft_276 = tt.expand_dims %ileft_275 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc230) + %ileft_277 = tt.broadcast %ileft_276 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc231) + %iright_278 = arith.muli %y_273, %flip_53 : tensor<2x2x4xi32> loc(#loc232) + %iright_279 = "tt.reduce"(%iright_278) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc303) + %iright_280 = tt.expand_dims %iright_279 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc234) + %iright_281 = tt.broadcast %iright_280 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc235) + %ileft_282 = tt.reshape %ileft_277 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_283 = tt.reshape %iright_281 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_284 = tt.reshape %new_idxs_272 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc238) + %left_idx_285 = arith.muli %y_idx_284, %ileft_131 : tensor<2x2x4xi32> loc(#loc240) + %left_idx_286 = "tt.reduce"(%left_idx_285) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc306) + %left_idx_287 = tt.expand_dims %left_idx_286 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc242) + %left_idx_288 = tt.broadcast %left_idx_287 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc243) + %right_idx_289 = arith.muli %y_idx_284, %flip_53 : tensor<2x2x4xi32> loc(#loc245) + %right_idx_290 = "tt.reduce"(%right_idx_289) 
<{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc309) + %right_idx_291 = tt.expand_dims %right_idx_290 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc247) + %right_idx_292 = tt.broadcast %right_idx_291 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc248) + %left_idx_293 = tt.reshape %left_idx_288 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_294 = tt.reshape %right_idx_292 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_295 = arith.cmpi slt, %ileft_282, %iright_283 : tensor<1x16xi32> loc(#loc251) + %eq_296 = arith.cmpi eq, %ileft_282, %iright_283 : tensor<1x16xi32> loc(#loc252) + %cond_297 = arith.cmpi sgt, %left_idx_293, %right_idx_294 : tensor<1x16xi32> loc(#loc253) + %cond_298 = arith.andi %eq_296, %cond_297 : tensor<1x16xi1> loc(#loc254) + %cond_299 = arith.ori %cond_295, %cond_298 : tensor<1x16xi1> loc(#loc255) + %ret_300 = arith.xori %ileft_282, %iright_283 : tensor<1x16xi32> loc(#loc258) + %ret_301 = arith.select %cond_299, %ret_300, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_302 = arith.xori %ret_269, %ret_301 : tensor<1x16xi32> loc(#loc260) + %new_idxs_303 = arith.xori %left_idx_293, %right_idx_294 : tensor<1x16xi32> loc(#loc261) + %new_idxs_304 = arith.select %cond_299, %new_idxs_303, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_305 = arith.xori %new_idxs_272, %new_idxs_304 : tensor<1x16xi32> loc(#loc263) + %y_306 = tt.reshape %ret_302 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc226) + %ileft_307 = arith.muli %y_306, %ileft_56 : tensor<4x2x2xi32> loc(#loc228) + %ileft_308 = "tt.reduce"(%ileft_307) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 
loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc301) + %ileft_309 = tt.expand_dims %ileft_308 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc230) + %ileft_310 = tt.broadcast %ileft_309 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc231) + %iright_311 = arith.muli %y_306, %flip_17 : tensor<4x2x2xi32> loc(#loc232) + %iright_312 = "tt.reduce"(%iright_311) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc303) + %iright_313 = tt.expand_dims %iright_312 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc234) + %iright_314 = tt.broadcast %iright_313 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc235) + %ileft_315 = tt.reshape %ileft_310 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_316 = tt.reshape %iright_314 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_317 = tt.reshape %new_idxs_305 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc238) + %left_idx_318 = arith.muli %y_idx_317, %ileft_56 : tensor<4x2x2xi32> loc(#loc240) + %left_idx_319 = "tt.reduce"(%left_idx_318) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc306) + %left_idx_320 = tt.expand_dims %left_idx_319 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc242) + %left_idx_321 = tt.broadcast %left_idx_320 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc243) + 
%right_idx_322 = arith.muli %y_idx_317, %flip_17 : tensor<4x2x2xi32> loc(#loc245) + %right_idx_323 = "tt.reduce"(%right_idx_322) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc309) + %right_idx_324 = tt.expand_dims %right_idx_323 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc247) + %right_idx_325 = tt.broadcast %right_idx_324 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc248) + %left_idx_326 = tt.reshape %left_idx_321 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_327 = tt.reshape %right_idx_325 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_328 = arith.cmpi slt, %ileft_315, %iright_316 : tensor<1x16xi32> loc(#loc251) + %eq_329 = arith.cmpi eq, %ileft_315, %iright_316 : tensor<1x16xi32> loc(#loc252) + %cond_330 = arith.cmpi sgt, %left_idx_326, %right_idx_327 : tensor<1x16xi32> loc(#loc253) + %cond_331 = arith.andi %eq_329, %cond_330 : tensor<1x16xi1> loc(#loc254) + %cond_332 = arith.ori %cond_328, %cond_331 : tensor<1x16xi1> loc(#loc255) + %ret_333 = arith.xori %ileft_315, %iright_316 : tensor<1x16xi32> loc(#loc258) + %ret_334 = arith.select %cond_332, %ret_333, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc259) + %ret_335 = arith.xori %ret_302, %ret_334 : tensor<1x16xi32> loc(#loc260) + %new_idxs_336 = arith.xori %left_idx_326, %right_idx_327 : tensor<1x16xi32> loc(#loc261) + %new_idxs_337 = arith.select %cond_332, %new_idxs_336, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_338 = arith.xori %new_idxs_305, %new_idxs_337 : tensor<1x16xi32> loc(#loc263) + %y_339 = tt.reshape %ret_335 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc226) + %ileft_340 = arith.muli %y_339, %ileft : tensor<8x2x1xi32> loc(#loc228) + 
%ileft_341 = "tt.reduce"(%ileft_340) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc229)), %ileft_706: i32 loc(callsite(#loc1 at #loc229))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc319) + tt.reduce.return %ileft_707 : i32 loc(#loc301) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc301) + %ileft_342 = tt.expand_dims %ileft_341 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc230) + %ileft_343 = tt.broadcast %ileft_342 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc231) + %iright_344 = arith.muli %y_339, %iright : tensor<8x2x1xi32> loc(#loc232) + %iright_345 = "tt.reduce"(%iright_344) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc233)), %iright_706: i32 loc(callsite(#loc1 at #loc233))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc320) + tt.reduce.return %iright_707 : i32 loc(#loc303) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc303) + %iright_346 = tt.expand_dims %iright_345 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc234) + %iright_347 = tt.broadcast %iright_346 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc235) + %ileft_348 = tt.reshape %ileft_343 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc236) + %iright_349 = tt.reshape %iright_347 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc237) + %y_idx_350 = tt.reshape %new_idxs_338 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc238) + %left_idx_351 = arith.muli %y_idx_350, %ileft : tensor<8x2x1xi32> loc(#loc240) + %left_idx_352 = "tt.reduce"(%left_idx_351) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc241)), %left_idx_706: i32 loc(callsite(#loc1 at #loc241))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc321) + tt.reduce.return %left_idx_707 : i32 loc(#loc306) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc306) + %left_idx_353 = tt.expand_dims %left_idx_352 {axis = 1 : i32} : tensor<8x1xi32> -> 
tensor<8x1x1xi32> loc(#loc242) + %left_idx_354 = tt.broadcast %left_idx_353 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc243) + %right_idx_355 = arith.muli %y_idx_350, %iright : tensor<8x2x1xi32> loc(#loc245) + %right_idx_356 = "tt.reduce"(%right_idx_355) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc246)), %right_idx_706: i32 loc(callsite(#loc1 at #loc246))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc322) + tt.reduce.return %right_idx_707 : i32 loc(#loc309) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc309) + %right_idx_357 = tt.expand_dims %right_idx_356 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc247) + %right_idx_358 = tt.broadcast %right_idx_357 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc248) + %left_idx_359 = tt.reshape %left_idx_354 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc249) + %right_idx_360 = tt.reshape %right_idx_358 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc250) + %cond_361 = arith.cmpi slt, %ileft_348, %iright_349 : tensor<1x16xi32> loc(#loc251) + %eq_362 = arith.cmpi eq, %ileft_348, %iright_349 : tensor<1x16xi32> loc(#loc252) + %cond_363 = arith.cmpi sgt, %left_idx_359, %right_idx_360 : tensor<1x16xi32> loc(#loc253) + %cond_364 = arith.andi %eq_362, %cond_363 : tensor<1x16xi1> loc(#loc254) + %cond_365 = arith.ori %cond_361, %cond_364 : tensor<1x16xi1> loc(#loc255) + %new_idxs_366 = arith.xori %left_idx_359, %right_idx_360 : tensor<1x16xi32> loc(#loc261) + %new_idxs_367 = arith.select %cond_365, %new_idxs_366, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc262) + %new_idxs_368 = arith.xori %new_idxs_338, %new_idxs_367 : tensor<1x16xi32> loc(#loc263) + %tmp14 = arith.cmpi eq, %tmp0_14, %cst_4 : tensor<1x16xi64> loc(#loc182) + %tmp16 = arith.extui %tmp14 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc218) + %y_369 = tt.reshape %tmp16 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc264) + %ileft_370 = arith.muli %y_369, %ileft : 
tensor<8x2x1xi32> loc(#loc265) + %ileft_371 = "tt.reduce"(%ileft_370) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc311) + %ileft_372 = tt.expand_dims %ileft_371 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc267) + %ileft_373 = tt.broadcast %ileft_372 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc268) + %iright_374 = arith.muli %y_369, %iright : tensor<8x2x1xi32> loc(#loc269) + %iright_375 = "tt.reduce"(%iright_374) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc313) + %iright_376 = tt.expand_dims %iright_375 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc271) + %iright_377 = tt.broadcast %iright_376 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc272) + %ileft_378 = tt.reshape %ileft_373 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_379 = tt.reshape %iright_377 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc274) + %cond_380 = arith.cmpi slt, %ileft_378, %iright_379 : tensor<1x16xi32> loc(#loc275) + %eq_381 = arith.cmpi eq, %ileft_378, %iright_379 : tensor<1x16xi32> loc(#loc276) + %cond_382 = arith.andi %eq_381, %cond_42 : tensor<1x16xi1> loc(#loc277) + %cond_383 = arith.ori %cond_380, %cond_382 : tensor<1x16xi1> loc(#loc278) + %cond_384 = arith.extui %cond_383 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_385 = arith.xori %cond_384, %flip_18 : tensor<1x16xi32> loc(#loc279) + %cond_386 = arith.cmpi ne, %cond_385, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_387 = arith.xori %ileft_378, 
%iright_379 : tensor<1x16xi32> loc(#loc281) + %ret_388 = arith.select %cond_386, %ret_387, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_389 = arith.xori %tmp16, %ret_388 : tensor<1x16xi32> loc(#loc283) + %new_idxs_390 = arith.select %cond_386, %new_idxs, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_391 = arith.xori %new_idxs_51, %new_idxs_390 : tensor<1x16xi32> loc(#loc285) + %y_392 = tt.reshape %ret_389 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc264) + %ileft_393 = arith.muli %y_392, %ileft_56 : tensor<4x2x2xi32> loc(#loc265) + %ileft_394 = "tt.reduce"(%ileft_393) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc311) + %ileft_395 = tt.expand_dims %ileft_394 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc267) + %ileft_396 = tt.broadcast %ileft_395 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc268) + %iright_397 = arith.muli %y_392, %flip_17 : tensor<4x2x2xi32> loc(#loc269) + %iright_398 = "tt.reduce"(%iright_397) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc313) + %iright_399 = tt.expand_dims %iright_398 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc271) + %iright_400 = tt.broadcast %iright_399 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc272) + %ileft_401 = tt.reshape %ileft_396 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_402 = tt.reshape %iright_400 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_403 = tt.reshape %new_idxs_391 : 
tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc286) + %left_idx_404 = arith.muli %y_idx_403, %ileft_56 : tensor<4x2x2xi32> loc(#loc287) + %left_idx_405 = "tt.reduce"(%left_idx_404) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc315) + %left_idx_406 = tt.expand_dims %left_idx_405 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc289) + %left_idx_407 = tt.broadcast %left_idx_406 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc290) + %right_idx_408 = arith.muli %y_idx_403, %flip_17 : tensor<4x2x2xi32> loc(#loc291) + %right_idx_409 = "tt.reduce"(%right_idx_408) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc317) + %right_idx_410 = tt.expand_dims %right_idx_409 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc293) + %right_idx_411 = tt.broadcast %right_idx_410 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc294) + %left_idx_412 = tt.reshape %left_idx_407 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_413 = tt.reshape %right_idx_411 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_414 = arith.cmpi slt, %ileft_401, %iright_402 : tensor<1x16xi32> loc(#loc275) + %eq_415 = arith.cmpi eq, %ileft_401, %iright_402 : tensor<1x16xi32> loc(#loc276) + %cond_416 = arith.cmpi sgt, %left_idx_412, %right_idx_413 : tensor<1x16xi32> loc(#loc297) + %cond_417 = arith.andi %eq_415, %cond_416 : tensor<1x16xi1> loc(#loc277) + %cond_418 = arith.ori %cond_414, %cond_417 : tensor<1x16xi1> 
loc(#loc278) + %cond_419 = arith.extui %cond_418 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_420 = arith.xori %cond_419, %flip_54 : tensor<1x16xi32> loc(#loc279) + %cond_421 = arith.cmpi ne, %cond_420, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_422 = arith.xori %ileft_401, %iright_402 : tensor<1x16xi32> loc(#loc281) + %ret_423 = arith.select %cond_421, %ret_422, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_424 = arith.xori %ret_389, %ret_423 : tensor<1x16xi32> loc(#loc283) + %new_idxs_425 = arith.xori %left_idx_412, %right_idx_413 : tensor<1x16xi32> loc(#loc298) + %new_idxs_426 = arith.select %cond_421, %new_idxs_425, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_427 = arith.xori %new_idxs_391, %new_idxs_426 : tensor<1x16xi32> loc(#loc285) + %y_428 = tt.reshape %ret_424 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc264) + %ileft_429 = arith.muli %y_428, %ileft : tensor<8x2x1xi32> loc(#loc265) + %ileft_430 = "tt.reduce"(%ileft_429) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc311) + %ileft_431 = tt.expand_dims %ileft_430 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc267) + %ileft_432 = tt.broadcast %ileft_431 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc268) + %iright_433 = arith.muli %y_428, %iright : tensor<8x2x1xi32> loc(#loc269) + %iright_434 = "tt.reduce"(%iright_433) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc313) + %iright_435 = tt.expand_dims %iright_434 {axis = 1 : i32} 
: tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc271) + %iright_436 = tt.broadcast %iright_435 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc272) + %ileft_437 = tt.reshape %ileft_432 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_438 = tt.reshape %iright_436 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_439 = tt.reshape %new_idxs_427 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc286) + %left_idx_440 = arith.muli %y_idx_439, %ileft : tensor<8x2x1xi32> loc(#loc287) + %left_idx_441 = "tt.reduce"(%left_idx_440) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc315) + %left_idx_442 = tt.expand_dims %left_idx_441 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %left_idx_443 = tt.broadcast %left_idx_442 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %right_idx_444 = arith.muli %y_idx_439, %iright : tensor<8x2x1xi32> loc(#loc291) + %right_idx_445 = "tt.reduce"(%right_idx_444) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc317) + %right_idx_446 = tt.expand_dims %right_idx_445 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc293) + %right_idx_447 = tt.broadcast %right_idx_446 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc294) + %left_idx_448 = tt.reshape %left_idx_443 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_449 = tt.reshape %right_idx_447 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_450 = arith.cmpi slt, %ileft_437, 
%iright_438 : tensor<1x16xi32> loc(#loc275) + %eq_451 = arith.cmpi eq, %ileft_437, %iright_438 : tensor<1x16xi32> loc(#loc276) + %cond_452 = arith.cmpi sgt, %left_idx_448, %right_idx_449 : tensor<1x16xi32> loc(#loc297) + %cond_453 = arith.andi %eq_451, %cond_452 : tensor<1x16xi1> loc(#loc277) + %cond_454 = arith.ori %cond_450, %cond_453 : tensor<1x16xi1> loc(#loc278) + %cond_455 = arith.extui %cond_454 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_456 = arith.xori %cond_455, %flip_54 : tensor<1x16xi32> loc(#loc279) + %cond_457 = arith.cmpi ne, %cond_456, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_458 = arith.xori %ileft_437, %iright_438 : tensor<1x16xi32> loc(#loc281) + %ret_459 = arith.select %cond_457, %ret_458, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_460 = arith.xori %ret_424, %ret_459 : tensor<1x16xi32> loc(#loc283) + %new_idxs_461 = arith.xori %left_idx_448, %right_idx_449 : tensor<1x16xi32> loc(#loc298) + %new_idxs_462 = arith.select %cond_457, %new_idxs_461, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_463 = arith.xori %new_idxs_427, %new_idxs_462 : tensor<1x16xi32> loc(#loc285) + %y_464 = tt.reshape %ret_460 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc264) + %ileft_465 = arith.muli %y_464, %ileft_131 : tensor<2x2x4xi32> loc(#loc265) + %ileft_466 = "tt.reduce"(%ileft_465) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc311) + %ileft_467 = tt.expand_dims %ileft_466 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc267) + %ileft_468 = tt.broadcast %ileft_467 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc268) + %iright_469 = arith.muli %y_464, %flip_53 : tensor<2x2x4xi32> loc(#loc269) + %iright_470 = "tt.reduce"(%iright_469) <{axis = 
1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc313) + %iright_471 = tt.expand_dims %iright_470 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc271) + %iright_472 = tt.broadcast %iright_471 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc272) + %ileft_473 = tt.reshape %ileft_468 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_474 = tt.reshape %iright_472 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_475 = tt.reshape %new_idxs_463 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc286) + %left_idx_476 = arith.muli %y_idx_475, %ileft_131 : tensor<2x2x4xi32> loc(#loc287) + %left_idx_477 = "tt.reduce"(%left_idx_476) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc315) + %left_idx_478 = tt.expand_dims %left_idx_477 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc289) + %left_idx_479 = tt.broadcast %left_idx_478 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc290) + %right_idx_480 = arith.muli %y_idx_475, %flip_53 : tensor<2x2x4xi32> loc(#loc291) + %right_idx_481 = "tt.reduce"(%right_idx_480) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc317) + %right_idx_482 = tt.expand_dims %right_idx_481 {axis = 1 : i32} : tensor<2x4xi32> -> 
tensor<2x1x4xi32> loc(#loc293) + %right_idx_483 = tt.broadcast %right_idx_482 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc294) + %left_idx_484 = tt.reshape %left_idx_479 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_485 = tt.reshape %right_idx_483 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_486 = arith.cmpi slt, %ileft_473, %iright_474 : tensor<1x16xi32> loc(#loc275) + %eq_487 = arith.cmpi eq, %ileft_473, %iright_474 : tensor<1x16xi32> loc(#loc276) + %cond_488 = arith.cmpi sgt, %left_idx_484, %right_idx_485 : tensor<1x16xi32> loc(#loc297) + %cond_489 = arith.andi %eq_487, %cond_488 : tensor<1x16xi1> loc(#loc277) + %cond_490 = arith.ori %cond_486, %cond_489 : tensor<1x16xi1> loc(#loc278) + %cond_491 = arith.extui %cond_490 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_492 = arith.xori %cond_491, %flip_129 : tensor<1x16xi32> loc(#loc279) + %cond_493 = arith.cmpi ne, %cond_492, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_494 = arith.xori %ileft_473, %iright_474 : tensor<1x16xi32> loc(#loc281) + %ret_495 = arith.select %cond_493, %ret_494, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_496 = arith.xori %ret_460, %ret_495 : tensor<1x16xi32> loc(#loc283) + %new_idxs_497 = arith.xori %left_idx_484, %right_idx_485 : tensor<1x16xi32> loc(#loc298) + %new_idxs_498 = arith.select %cond_493, %new_idxs_497, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_499 = arith.xori %new_idxs_463, %new_idxs_498 : tensor<1x16xi32> loc(#loc285) + %y_500 = tt.reshape %ret_496 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc264) + %ileft_501 = arith.muli %y_500, %ileft_56 : tensor<4x2x2xi32> loc(#loc265) + %ileft_502 = "tt.reduce"(%ileft_501) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : 
(tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc311) + %ileft_503 = tt.expand_dims %ileft_502 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc267) + %ileft_504 = tt.broadcast %ileft_503 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc268) + %iright_505 = arith.muli %y_500, %flip_17 : tensor<4x2x2xi32> loc(#loc269) + %iright_506 = "tt.reduce"(%iright_505) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc313) + %iright_507 = tt.expand_dims %iright_506 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc271) + %iright_508 = tt.broadcast %iright_507 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc272) + %ileft_509 = tt.reshape %ileft_504 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_510 = tt.reshape %iright_508 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_511 = tt.reshape %new_idxs_499 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc286) + %left_idx_512 = arith.muli %y_idx_511, %ileft_56 : tensor<4x2x2xi32> loc(#loc287) + %left_idx_513 = "tt.reduce"(%left_idx_512) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc315) + %left_idx_514 = tt.expand_dims %left_idx_513 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc289) + %left_idx_515 = tt.broadcast %left_idx_514 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc290) + %right_idx_516 = arith.muli %y_idx_511, %flip_17 : tensor<4x2x2xi32> loc(#loc291) + %right_idx_517 = "tt.reduce"(%right_idx_516) <{axis = 1 : i32}> ({ + 
^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc317) + %right_idx_518 = tt.expand_dims %right_idx_517 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc293) + %right_idx_519 = tt.broadcast %right_idx_518 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc294) + %left_idx_520 = tt.reshape %left_idx_515 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_521 = tt.reshape %right_idx_519 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_522 = arith.cmpi slt, %ileft_509, %iright_510 : tensor<1x16xi32> loc(#loc275) + %eq_523 = arith.cmpi eq, %ileft_509, %iright_510 : tensor<1x16xi32> loc(#loc276) + %cond_524 = arith.cmpi sgt, %left_idx_520, %right_idx_521 : tensor<1x16xi32> loc(#loc297) + %cond_525 = arith.andi %eq_523, %cond_524 : tensor<1x16xi1> loc(#loc277) + %cond_526 = arith.ori %cond_522, %cond_525 : tensor<1x16xi1> loc(#loc278) + %cond_527 = arith.extui %cond_526 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_528 = arith.xori %cond_527, %flip_129 : tensor<1x16xi32> loc(#loc279) + %cond_529 = arith.cmpi ne, %cond_528, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_530 = arith.xori %ileft_509, %iright_510 : tensor<1x16xi32> loc(#loc281) + %ret_531 = arith.select %cond_529, %ret_530, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_532 = arith.xori %ret_496, %ret_531 : tensor<1x16xi32> loc(#loc283) + %new_idxs_533 = arith.xori %left_idx_520, %right_idx_521 : tensor<1x16xi32> loc(#loc298) + %new_idxs_534 = arith.select %cond_529, %new_idxs_533, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_535 = arith.xori %new_idxs_499, %new_idxs_534 : tensor<1x16xi32> loc(#loc285) + %y_536 = tt.reshape %ret_532 : tensor<1x16xi32> -> tensor<8x2x1xi32> 
loc(#loc264) + %ileft_537 = arith.muli %y_536, %ileft : tensor<8x2x1xi32> loc(#loc265) + %ileft_538 = "tt.reduce"(%ileft_537) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc311) + %ileft_539 = tt.expand_dims %ileft_538 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc267) + %ileft_540 = tt.broadcast %ileft_539 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc268) + %iright_541 = arith.muli %y_536, %iright : tensor<8x2x1xi32> loc(#loc269) + %iright_542 = "tt.reduce"(%iright_541) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc313) + %iright_543 = tt.expand_dims %iright_542 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc271) + %iright_544 = tt.broadcast %iright_543 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc272) + %ileft_545 = tt.reshape %ileft_540 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_546 = tt.reshape %iright_544 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_547 = tt.reshape %new_idxs_535 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc286) + %left_idx_548 = arith.muli %y_idx_547, %ileft : tensor<8x2x1xi32> loc(#loc287) + %left_idx_549 = "tt.reduce"(%left_idx_548) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc315) + 
%left_idx_550 = tt.expand_dims %left_idx_549 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %left_idx_551 = tt.broadcast %left_idx_550 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %right_idx_552 = arith.muli %y_idx_547, %iright : tensor<8x2x1xi32> loc(#loc291) + %right_idx_553 = "tt.reduce"(%right_idx_552) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc317) + %right_idx_554 = tt.expand_dims %right_idx_553 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc293) + %right_idx_555 = tt.broadcast %right_idx_554 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc294) + %left_idx_556 = tt.reshape %left_idx_551 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_557 = tt.reshape %right_idx_555 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_558 = arith.cmpi slt, %ileft_545, %iright_546 : tensor<1x16xi32> loc(#loc275) + %eq_559 = arith.cmpi eq, %ileft_545, %iright_546 : tensor<1x16xi32> loc(#loc276) + %cond_560 = arith.cmpi sgt, %left_idx_556, %right_idx_557 : tensor<1x16xi32> loc(#loc297) + %cond_561 = arith.andi %eq_559, %cond_560 : tensor<1x16xi1> loc(#loc277) + %cond_562 = arith.ori %cond_558, %cond_561 : tensor<1x16xi1> loc(#loc278) + %cond_563 = arith.extui %cond_562 : tensor<1x16xi1> to tensor<1x16xi32> loc(#loc279) + %cond_564 = arith.xori %cond_563, %flip_129 : tensor<1x16xi32> loc(#loc279) + %cond_565 = arith.cmpi ne, %cond_564, %cst_1 : tensor<1x16xi32> loc(#loc280) + %ret_566 = arith.xori %ileft_545, %iright_546 : tensor<1x16xi32> loc(#loc281) + %ret_567 = arith.select %cond_565, %ret_566, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_568 = arith.xori %ret_532, %ret_567 : tensor<1x16xi32> loc(#loc283) + 
%new_idxs_569 = arith.xori %left_idx_556, %right_idx_557 : tensor<1x16xi32> loc(#loc298) + %new_idxs_570 = arith.select %cond_565, %new_idxs_569, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_571 = arith.xori %new_idxs_535, %new_idxs_570 : tensor<1x16xi32> loc(#loc285) + %y_572 = tt.reshape %ret_568 : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc264) + %ileft_573 = arith.muli %y_572, %ileft_240 : tensor<1x2x8xi32> loc(#loc265) + %ileft_574 = "tt.reduce"(%ileft_573) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc311) + %ileft_575 = tt.expand_dims %ileft_574 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc267) + %ileft_576 = tt.broadcast %ileft_575 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc268) + %iright_577 = arith.muli %y_572, %flip_128 : tensor<1x2x8xi32> loc(#loc269) + %iright_578 = "tt.reduce"(%iright_577) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc313) + %iright_579 = tt.expand_dims %iright_578 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc271) + %iright_580 = tt.broadcast %iright_579 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc272) + %ileft_581 = tt.reshape %ileft_576 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_582 = tt.reshape %iright_580 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_583 = tt.reshape %new_idxs_571 : tensor<1x16xi32> -> tensor<1x2x8xi32> loc(#loc286) + %left_idx_584 = arith.muli %y_idx_583, %ileft_240 : tensor<1x2x8xi32> loc(#loc287) + 
%left_idx_585 = "tt.reduce"(%left_idx_584) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc315) + %left_idx_586 = tt.expand_dims %left_idx_585 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc289) + %left_idx_587 = tt.broadcast %left_idx_586 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc290) + %right_idx_588 = arith.muli %y_idx_583, %flip_128 : tensor<1x2x8xi32> loc(#loc291) + %right_idx_589 = "tt.reduce"(%right_idx_588) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<1x2x8xi32>) -> tensor<1x8xi32> loc(#loc317) + %right_idx_590 = tt.expand_dims %right_idx_589 {axis = 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> loc(#loc293) + %right_idx_591 = tt.broadcast %right_idx_590 : tensor<1x1x8xi32> -> tensor<1x2x8xi32> loc(#loc294) + %left_idx_592 = tt.reshape %left_idx_587 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_593 = tt.reshape %right_idx_591 : tensor<1x2x8xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_594 = arith.cmpi slt, %ileft_581, %iright_582 : tensor<1x16xi32> loc(#loc275) + %eq_595 = arith.cmpi eq, %ileft_581, %iright_582 : tensor<1x16xi32> loc(#loc276) + %cond_596 = arith.cmpi sgt, %left_idx_592, %right_idx_593 : tensor<1x16xi32> loc(#loc297) + %cond_597 = arith.andi %eq_595, %cond_596 : tensor<1x16xi1> loc(#loc277) + %cond_598 = arith.ori %cond_594, %cond_597 : tensor<1x16xi1> loc(#loc278) + %ret_599 = arith.xori %ileft_581, %iright_582 : tensor<1x16xi32> loc(#loc281) + %ret_600 = arith.select %cond_598, %ret_599, %cst_1 
: tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_601 = arith.xori %ret_568, %ret_600 : tensor<1x16xi32> loc(#loc283) + %new_idxs_602 = arith.xori %left_idx_592, %right_idx_593 : tensor<1x16xi32> loc(#loc298) + %new_idxs_603 = arith.select %cond_598, %new_idxs_602, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_604 = arith.xori %new_idxs_571, %new_idxs_603 : tensor<1x16xi32> loc(#loc285) + %y_605 = tt.reshape %ret_601 : tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc264) + %ileft_606 = arith.muli %y_605, %ileft_131 : tensor<2x2x4xi32> loc(#loc265) + %ileft_607 = "tt.reduce"(%ileft_606) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc311) + %ileft_608 = tt.expand_dims %ileft_607 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc267) + %ileft_609 = tt.broadcast %ileft_608 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc268) + %iright_610 = arith.muli %y_605, %flip_53 : tensor<2x2x4xi32> loc(#loc269) + %iright_611 = "tt.reduce"(%iright_610) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc313) + %iright_612 = tt.expand_dims %iright_611 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc271) + %iright_613 = tt.broadcast %iright_612 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc272) + %ileft_614 = tt.reshape %ileft_609 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_615 = tt.reshape %iright_613 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_616 = tt.reshape %new_idxs_604 : 
tensor<1x16xi32> -> tensor<2x2x4xi32> loc(#loc286) + %left_idx_617 = arith.muli %y_idx_616, %ileft_131 : tensor<2x2x4xi32> loc(#loc287) + %left_idx_618 = "tt.reduce"(%left_idx_617) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc315) + %left_idx_619 = tt.expand_dims %left_idx_618 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc289) + %left_idx_620 = tt.broadcast %left_idx_619 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc290) + %right_idx_621 = arith.muli %y_idx_616, %flip_53 : tensor<2x2x4xi32> loc(#loc291) + %right_idx_622 = "tt.reduce"(%right_idx_621) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<2x2x4xi32>) -> tensor<2x4xi32> loc(#loc317) + %right_idx_623 = tt.expand_dims %right_idx_622 {axis = 1 : i32} : tensor<2x4xi32> -> tensor<2x1x4xi32> loc(#loc293) + %right_idx_624 = tt.broadcast %right_idx_623 : tensor<2x1x4xi32> -> tensor<2x2x4xi32> loc(#loc294) + %left_idx_625 = tt.reshape %left_idx_620 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_626 = tt.reshape %right_idx_624 : tensor<2x2x4xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_627 = arith.cmpi slt, %ileft_614, %iright_615 : tensor<1x16xi32> loc(#loc275) + %eq_628 = arith.cmpi eq, %ileft_614, %iright_615 : tensor<1x16xi32> loc(#loc276) + %cond_629 = arith.cmpi sgt, %left_idx_625, %right_idx_626 : tensor<1x16xi32> loc(#loc297) + %cond_630 = arith.andi %eq_628, %cond_629 : tensor<1x16xi1> loc(#loc277) + %cond_631 = arith.ori %cond_627, %cond_630 : tensor<1x16xi1> 
loc(#loc278) + %ret_632 = arith.xori %ileft_614, %iright_615 : tensor<1x16xi32> loc(#loc281) + %ret_633 = arith.select %cond_631, %ret_632, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_634 = arith.xori %ret_601, %ret_633 : tensor<1x16xi32> loc(#loc283) + %new_idxs_635 = arith.xori %left_idx_625, %right_idx_626 : tensor<1x16xi32> loc(#loc298) + %new_idxs_636 = arith.select %cond_631, %new_idxs_635, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_637 = arith.xori %new_idxs_604, %new_idxs_636 : tensor<1x16xi32> loc(#loc285) + %y_638 = tt.reshape %ret_634 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc264) + %ileft_639 = arith.muli %y_638, %ileft_56 : tensor<4x2x2xi32> loc(#loc265) + %ileft_640 = "tt.reduce"(%ileft_639) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc311) + %ileft_641 = tt.expand_dims %ileft_640 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc267) + %ileft_642 = tt.broadcast %ileft_641 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc268) + %iright_643 = arith.muli %y_638, %flip_17 : tensor<4x2x2xi32> loc(#loc269) + %iright_644 = "tt.reduce"(%iright_643) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc313) + %iright_645 = tt.expand_dims %iright_644 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc271) + %iright_646 = tt.broadcast %iright_645 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc272) + %ileft_647 = tt.reshape %ileft_642 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc273) + 
%iright_648 = tt.reshape %iright_646 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_649 = tt.reshape %new_idxs_637 : tensor<1x16xi32> -> tensor<4x2x2xi32> loc(#loc286) + %left_idx_650 = arith.muli %y_idx_649, %ileft_56 : tensor<4x2x2xi32> loc(#loc287) + %left_idx_651 = "tt.reduce"(%left_idx_650) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc315) + %left_idx_652 = tt.expand_dims %left_idx_651 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc289) + %left_idx_653 = tt.broadcast %left_idx_652 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc290) + %right_idx_654 = arith.muli %y_idx_649, %flip_17 : tensor<4x2x2xi32> loc(#loc291) + %right_idx_655 = "tt.reduce"(%right_idx_654) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<4x2x2xi32>) -> tensor<4x2xi32> loc(#loc317) + %right_idx_656 = tt.expand_dims %right_idx_655 {axis = 1 : i32} : tensor<4x2xi32> -> tensor<4x1x2xi32> loc(#loc293) + %right_idx_657 = tt.broadcast %right_idx_656 : tensor<4x1x2xi32> -> tensor<4x2x2xi32> loc(#loc294) + %left_idx_658 = tt.reshape %left_idx_653 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_659 = tt.reshape %right_idx_657 : tensor<4x2x2xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_660 = arith.cmpi slt, %ileft_647, %iright_648 : tensor<1x16xi32> loc(#loc275) + %eq_661 = arith.cmpi eq, %ileft_647, %iright_648 : tensor<1x16xi32> loc(#loc276) + %cond_662 = arith.cmpi sgt, %left_idx_658, %right_idx_659 : tensor<1x16xi32> loc(#loc297) + %cond_663 = 
arith.andi %eq_661, %cond_662 : tensor<1x16xi1> loc(#loc277) + %cond_664 = arith.ori %cond_660, %cond_663 : tensor<1x16xi1> loc(#loc278) + %ret_665 = arith.xori %ileft_647, %iright_648 : tensor<1x16xi32> loc(#loc281) + %ret_666 = arith.select %cond_664, %ret_665, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc282) + %ret_667 = arith.xori %ret_634, %ret_666 : tensor<1x16xi32> loc(#loc283) + %new_idxs_668 = arith.xori %left_idx_658, %right_idx_659 : tensor<1x16xi32> loc(#loc298) + %new_idxs_669 = arith.select %cond_664, %new_idxs_668, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_670 = arith.xori %new_idxs_637, %new_idxs_669 : tensor<1x16xi32> loc(#loc285) + %y_671 = tt.reshape %ret_667 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc264) + %ileft_672 = arith.muli %y_671, %ileft : tensor<8x2x1xi32> loc(#loc265) + %ileft_673 = "tt.reduce"(%ileft_672) <{axis = 1 : i32}> ({ + ^bb0(%ileft_705: i32 loc(callsite(#loc1 at #loc266)), %ileft_706: i32 loc(callsite(#loc1 at #loc266))): + %ileft_707 = arith.addi %ileft_705, %ileft_706 : i32 loc(#loc323) + tt.reduce.return %ileft_707 : i32 loc(#loc311) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc311) + %ileft_674 = tt.expand_dims %ileft_673 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc267) + %ileft_675 = tt.broadcast %ileft_674 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc268) + %iright_676 = arith.muli %y_671, %iright : tensor<8x2x1xi32> loc(#loc269) + %iright_677 = "tt.reduce"(%iright_676) <{axis = 1 : i32}> ({ + ^bb0(%iright_705: i32 loc(callsite(#loc1 at #loc270)), %iright_706: i32 loc(callsite(#loc1 at #loc270))): + %iright_707 = arith.addi %iright_705, %iright_706 : i32 loc(#loc324) + tt.reduce.return %iright_707 : i32 loc(#loc313) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc313) + %iright_678 = tt.expand_dims %iright_677 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc271) + %iright_679 = tt.broadcast %iright_678 : tensor<8x1x1xi32> -> 
tensor<8x2x1xi32> loc(#loc272) + %ileft_680 = tt.reshape %ileft_675 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc273) + %iright_681 = tt.reshape %iright_679 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc274) + %y_idx_682 = tt.reshape %new_idxs_670 : tensor<1x16xi32> -> tensor<8x2x1xi32> loc(#loc286) + %left_idx_683 = arith.muli %y_idx_682, %ileft : tensor<8x2x1xi32> loc(#loc287) + %left_idx_684 = "tt.reduce"(%left_idx_683) <{axis = 1 : i32}> ({ + ^bb0(%left_idx_705: i32 loc(callsite(#loc1 at #loc288)), %left_idx_706: i32 loc(callsite(#loc1 at #loc288))): + %left_idx_707 = arith.addi %left_idx_705, %left_idx_706 : i32 loc(#loc325) + tt.reduce.return %left_idx_707 : i32 loc(#loc315) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc315) + %left_idx_685 = tt.expand_dims %left_idx_684 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc289) + %left_idx_686 = tt.broadcast %left_idx_685 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc290) + %right_idx_687 = arith.muli %y_idx_682, %iright : tensor<8x2x1xi32> loc(#loc291) + %right_idx_688 = "tt.reduce"(%right_idx_687) <{axis = 1 : i32}> ({ + ^bb0(%right_idx_705: i32 loc(callsite(#loc1 at #loc292)), %right_idx_706: i32 loc(callsite(#loc1 at #loc292))): + %right_idx_707 = arith.addi %right_idx_705, %right_idx_706 : i32 loc(#loc326) + tt.reduce.return %right_idx_707 : i32 loc(#loc317) + }) : (tensor<8x2x1xi32>) -> tensor<8x1xi32> loc(#loc317) + %right_idx_689 = tt.expand_dims %right_idx_688 {axis = 1 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc293) + %right_idx_690 = tt.broadcast %right_idx_689 : tensor<8x1x1xi32> -> tensor<8x2x1xi32> loc(#loc294) + %left_idx_691 = tt.reshape %left_idx_686 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc295) + %right_idx_692 = tt.reshape %right_idx_690 : tensor<8x2x1xi32> -> tensor<1x16xi32> loc(#loc296) + %cond_693 = arith.cmpi slt, %ileft_680, %iright_681 : tensor<1x16xi32> loc(#loc275) + %eq_694 = arith.cmpi eq, %ileft_680, %iright_681 : tensor<1x16xi32> 
loc(#loc276) + %cond_695 = arith.cmpi sgt, %left_idx_691, %right_idx_692 : tensor<1x16xi32> loc(#loc297) + %cond_696 = arith.andi %eq_694, %cond_695 : tensor<1x16xi1> loc(#loc277) + %cond_697 = arith.ori %cond_693, %cond_696 : tensor<1x16xi1> loc(#loc278) + %new_idxs_698 = arith.xori %left_idx_691, %right_idx_692 : tensor<1x16xi32> loc(#loc298) + %new_idxs_699 = arith.select %cond_697, %new_idxs_698, %cst_1 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc284) + %new_idxs_700 = arith.xori %new_idxs_670, %new_idxs_699 : tensor<1x16xi32> loc(#loc285) + %tmp20 = arith.extui %tmp5 : tensor<1x16xi1> to tensor<1x16xi64> loc(#loc220) + %tmp23 = arith.select %tmp0_13, %tmp20, %cst_5 : tensor<1x16xi1>, tensor<1x16xi64> loc(#loc187) + %tmp24 = "tt.reduce"(%tmp23) <{axis = 1 : i32}> ({ + ^bb0(%tmp24_705: i64 loc(callsite(#loc1 at #loc188)), %tmp24_706: i64 loc(callsite(#loc1 at #loc188))): + %tmp24_707 = arith.addi %tmp24_705, %tmp24_706 : i64 loc(#loc299) + tt.reduce.return %tmp24_707 : i64 loc(#loc221) + }) : (tensor<1x16xi64>) -> tensor<1xi64> loc(#loc221) + %tmp24_701 = tt.expand_dims %tmp24 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc189) + %tmp25 = arith.extui %tmp14 : tensor<1x16xi1> to tensor<1x16xi64> loc(#loc223) + %tmp28 = arith.select %tmp0_13, %tmp25, %cst_5 : tensor<1x16xi1>, tensor<1x16xi64> loc(#loc191) + %tmp29 = "tt.reduce"(%tmp28) <{axis = 1 : i32}> ({ + ^bb0(%tmp29_705: i64 loc(callsite(#loc1 at #loc192)), %tmp29_706: i64 loc(callsite(#loc1 at #loc192))): + %tmp29_707 = arith.addi %tmp29_705, %tmp29_706 : i64 loc(#loc300) + tt.reduce.return %tmp29_707 : i64 loc(#loc224) + }) : (tensor<1x16xi64>) -> tensor<1xi64> loc(#loc224) + %tmp29_702 = tt.expand_dims %tmp29 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc193) + %tmp30 = arith.trunci %tmp24_701 : tensor<1x1xi64> to tensor<1x1xi32> loc(#loc194) + %tmp31 = arith.trunci %tmp29_702 : tensor<1x1xi64> to tensor<1x1xi32> loc(#loc195) + %tmp34 = tt.broadcast %tmp30 : tensor<1x1xi32> -> 
tensor<1x16xi32> loc(#loc196) + %tmp34_703 = arith.cmpi slt, %r0_index_8, %tmp34 : tensor<1x16xi32> loc(#loc196) + %tmp36 = arith.select %tmp34_703, %new_idxs_368, %cst_3 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc197) + %tmp38 = arith.addi %tmp36, %cst_2 : tensor<1x16xi32> loc(#loc198) + %tmp39 = arith.cmpi slt, %tmp36, %cst_1 : tensor<1x16xi32> loc(#loc199) + %tmp40 = arith.select %tmp39, %tmp38, %tmp36 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc200) + %0 = arith.cmpi sge, %tmp40, %cst_1 : tensor<1x16xi32> loc(#loc83) + %1 = arith.cmpi slt, %tmp40, %cst_2 : tensor<1x16xi32> loc(#loc84) + %2 = arith.andi %0, %1 : tensor<1x16xi1> loc(#loc85) + %3 = arith.xori %xmask_6, %true : i1 loc(#loc86) + %4 = tt.splat %3 : i1 -> tensor<1x16xi1> loc(#loc201) + %5 = arith.ori %2, %4 : tensor<1x16xi1> loc(#loc87) + tt.assert %5, "index out of bounds: 0 <= tmp40 < 17" : tensor<1x16xi1> loc(#loc88) + %tmp45 = tt.broadcast %tmp31 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc202) + %tmp45_704 = arith.cmpi slt, %r0_index_8, %tmp45 : tensor<1x16xi32> loc(#loc202) + %tmp46 = arith.select %tmp45_704, %new_idxs_700, %cst_3 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc203) + %tmp47 = arith.addi %tmp46, %cst_2 : tensor<1x16xi32> loc(#loc204) + %tmp48 = arith.cmpi slt, %tmp46, %cst_1 : tensor<1x16xi32> loc(#loc205) + %tmp49 = arith.select %tmp48, %tmp47, %tmp46 : tensor<1x16xi1>, tensor<1x16xi32> loc(#loc206) + %6 = arith.cmpi sge, %tmp49, %cst_1 : tensor<1x16xi32> loc(#loc94) + %7 = arith.cmpi slt, %tmp49, %cst_2 : tensor<1x16xi32> loc(#loc95) + %8 = arith.andi %6, %7 : tensor<1x16xi1> loc(#loc96) + %9 = arith.ori %8, %4 : tensor<1x16xi1> loc(#loc97) + tt.assert %9, "index out of bounds: 0 <= tmp49 < 17" : tensor<1x16xi1> loc(#loc98) + %10 = tt.addptr %out_ptr4, %xoffset : !tt.ptr, i32 loc(#loc99) + %11 = tt.splat %10 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc99) + tt.store %11, %tmp30, %xmask_7 : tensor<1x1x!tt.ptr> loc(#loc100) + %12 = tt.addptr %out_ptr5, %xoffset : !tt.ptr, i32 
loc(#loc101) + %13 = tt.splat %12 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc101) + tt.store %13, %tmp31, %xmask_7 : tensor<1x1x!tt.ptr> loc(#loc102) + %14 = tt.splat %out_ptr6 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc103) + %15 = tt.addptr %14, %tmp0_10 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc103) + tt.store %15, %new_idxs_368, %tmp0_13 : tensor<1x16x!tt.ptr> loc(#loc104) + %16 = arith.muli %xoffset, %c17_i32 : i32 loc(#loc105) + %17 = tt.splat %16 : i32 -> tensor<1x16xi32> loc(#loc207) + %18 = arith.addi %tmp40, %17 : tensor<1x16xi32> loc(#loc106) + %19 = tt.splat %out_ptr7 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc107) + %20 = tt.addptr %19, %18 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc107) + tt.store %20, %cst_0, %tmp0_13 : tensor<1x16x!tt.ptr> loc(#loc108) + %21 = tt.splat %out_ptr8 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc109) + %22 = tt.addptr %21, %tmp0_10 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc109) + tt.store %22, %new_idxs_700, %tmp0_13 : tensor<1x16x!tt.ptr> loc(#loc110) + %23 = arith.addi %tmp49, %17 : tensor<1x16xi32> loc(#loc111) + %24 = tt.splat %out_ptr9 : !tt.ptr -> tensor<1x16x!tt.ptr> loc(#loc112) + %25 = tt.addptr %24, %23 : tensor<1x16x!tt.ptr>, tensor<1x16xi32> loc(#loc112) + tt.store %25, %cst_0, %tmp0_13 : tensor<1x16x!tt.ptr> loc(#loc113) + tt.return loc(#loc114) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":26:21) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":24:28) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:28) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":27:38) +#loc6 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:40) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:37) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:30) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":34:45) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":36:18) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":38:18) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":39:18) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":41:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":40:19) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":43:19) +#loc16 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:41) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:44) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:60) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":627:68) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":533:22) +#loc24 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":537:21) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:40) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:65) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":538:78) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:41) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:67) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":539:80) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":540:30) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":541:32) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":546:29) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:36) +#loc39 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:23) +#loc40 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":290:25) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:53) +#loc43 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":548:66) +#loc44 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:37) +#loc45 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:23) +#loc47 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:54) +#loc48 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":551:67) +#loc49 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":553:36) +#loc50 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":554:38) +#loc51 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":574:22) +#loc52 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":591:21) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:40) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:29) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":594:23) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:19) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":599:28) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:38) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:46) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":600:15) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:48) +#loc62 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:59) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":601:22) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":47:20) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":49:21) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":48:21) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":52:20) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":54:35) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":55:29) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":56:21) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":58:35) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":59:29) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":60:21) +#loc77 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":61:21) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":64:19) +#loc79 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":66:35) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":68:20) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":69:20) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":70:35) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:28) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:46) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:38) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:55) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:53) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":71:63) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":75:19) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":76:35) +#loc91 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":77:20) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":78:20) +#loc93 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":79:35) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:28) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:46) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:38) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:53) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":80:63) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:25) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":81:37) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:25) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":82:37) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:25) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":83:47) +#loc105 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:52) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:49) +#loc107 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:25) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":84:85) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:25) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":85:47) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:49) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:25) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:85) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bs/cbs652i7ct57ugx6vnc35n63tinpzu6ba3zymufodig4hpzarwcl.py":86:4) +#loc124 = loc("xmask"(#loc2)) +#loc125 = loc("xoffset"(#loc3)) +#loc126 = loc("r0_index"(#loc4)) +#loc127 = loc("r0_index"(#loc5)) +#loc128 = loc("tmp0"(#loc6)) +#loc129 = loc("tmp0"(#loc7)) +#loc130 = loc("tmp0"(#loc8)) +#loc131 = loc("tmp0"(#loc9)) +#loc132 = loc("tmp2"(#loc10)) +#loc133 = loc("tmp4"(#loc11)) +#loc134 = loc("tmp5"(#loc12)) +#loc135 = loc("tmp7"(#loc13)) +#loc136 = loc("tmp6"(#loc14)) +#loc137 = loc("tmp9"(#loc15)) +#loc138 = loc("flip"(#loc16)) +#loc140 = loc("flip"(#loc19)) +#loc141 = loc("flip"(#loc20)) +#loc142 = loc("flip"(#loc21)) +#loc143 = loc("y"(#loc22)) +#loc144 = loc("left_mask"(#loc24)) +#loc145 = loc("ileft"(#loc25)) +#loc147 = loc("ileft"(#loc29)) +#loc148 = loc("ileft"(#loc30)) +#loc149 = loc("iright"(#loc31)) +#loc151 = loc("iright"(#loc33)) +#loc152 = loc("iright"(#loc34)) +#loc153 = loc("ileft"(#loc35)) +#loc154 = loc("iright"(#loc36)) +#loc155 = 
loc("y_idx"(#loc37)) +#loc156 = loc("left_idx"(#loc38)) +#loc157 = loc("left_idx"(#loc39)) +#loc158 = loc("input"(#loc40)) +#loc160 = loc("left_idx"(#loc42)) +#loc161 = loc("left_idx"(#loc43)) +#loc162 = loc("right_idx"(#loc44)) +#loc163 = loc("right_idx"(#loc45)) +#loc165 = loc("right_idx"(#loc47)) +#loc166 = loc("right_idx"(#loc48)) +#loc167 = loc("left_idx"(#loc49)) +#loc168 = loc("right_idx"(#loc50)) +#loc169 = loc("cond"(#loc51)) +#loc170 = loc("eq"(#loc52)) +#loc171 = loc("cond"(#loc53)) +#loc172 = loc("cond"(#loc54)) +#loc173 = loc("cond"(#loc55)) +#loc174 = loc("cond"(#loc56)) +#loc175 = loc("cond"(#loc57)) +#loc176 = loc("ret"(#loc58)) +#loc177 = loc("ret"(#loc59)) +#loc178 = loc("ret"(#loc60)) +#loc179 = loc("new_idxs"(#loc61)) +#loc180 = loc("new_idxs"(#loc62)) +#loc181 = loc("new_idxs"(#loc63)) +#loc182 = loc("tmp14"(#loc64)) +#loc183 = loc("tmp16"(#loc65)) +#loc184 = loc("tmp15"(#loc66)) +#loc186 = loc("tmp20"(#loc68)) +#loc187 = loc("tmp23"(#loc69)) +#loc189 = loc("tmp24"(#loc71)) +#loc190 = loc("tmp25"(#loc72)) +#loc191 = loc("tmp28"(#loc73)) +#loc193 = loc("tmp29"(#loc75)) +#loc194 = loc("tmp30"(#loc76)) +#loc195 = loc("tmp31"(#loc77)) +#loc196 = loc("tmp34"(#loc78)) +#loc197 = loc("tmp36"(#loc79)) +#loc198 = loc("tmp38"(#loc80)) +#loc199 = loc("tmp39"(#loc81)) +#loc200 = loc("tmp40"(#loc82)) +#loc201 = loc(fused[#loc87, #loc86]) +#loc202 = loc("tmp45"(#loc89)) +#loc203 = loc("tmp46"(#loc90)) +#loc204 = loc("tmp47"(#loc91)) +#loc205 = loc("tmp48"(#loc92)) +#loc206 = loc("tmp49"(#loc93)) +#loc207 = loc(fused[#loc106, #loc105]) +#loc208 = loc(fused[#loc129, #loc128]) +#loc209 = loc(fused[#loc131, #loc124]) +#loc210 = loc(fused[#loc135, #loc136]) +#loc211 = loc(callsite(#loc138 at #loc139)) +#loc212 = loc(callsite(#loc140 at #loc139)) +#loc213 = loc(callsite(#loc141 at #loc139)) +#loc214 = loc(callsite(#loc142 at #loc139)) +#loc216 = loc("cond"(#loc169)) +#loc217 = loc("eq"(#loc170)) +#loc218 = loc(fused[#loc183, #loc184]) +#loc220 = loc(fused[#loc186, 
#loc135, #loc136]) +#loc221 = loc(callsite(#loc26 at #loc188)) +#loc223 = loc(fused[#loc190, #loc183, #loc184]) +#loc224 = loc(callsite(#loc26 at #loc192)) +#loc226 = loc(callsite(#loc143 at #loc215)) +#loc227 = loc(callsite(#loc144 at #loc215)) +#loc228 = loc(callsite(#loc145 at #loc215)) +#loc230 = loc(callsite(#loc147 at #loc215)) +#loc231 = loc(callsite(#loc148 at #loc215)) +#loc232 = loc(callsite(#loc149 at #loc215)) +#loc234 = loc(callsite(#loc151 at #loc215)) +#loc235 = loc(callsite(#loc152 at #loc215)) +#loc236 = loc(callsite(#loc153 at #loc215)) +#loc237 = loc(callsite(#loc154 at #loc215)) +#loc238 = loc(callsite(#loc155 at #loc215)) +#loc239 = loc(callsite(#loc156 at #loc215)) +#loc240 = loc(callsite(#loc157 at #loc215)) +#loc242 = loc(callsite(#loc160 at #loc215)) +#loc243 = loc(callsite(#loc161 at #loc215)) +#loc244 = loc(callsite(#loc162 at #loc215)) +#loc245 = loc(callsite(#loc163 at #loc215)) +#loc247 = loc(callsite(#loc165 at #loc215)) +#loc248 = loc(callsite(#loc166 at #loc215)) +#loc249 = loc(callsite(#loc167 at #loc215)) +#loc250 = loc(callsite(#loc168 at #loc215)) +#loc251 = loc(callsite(#loc216 at #loc215)) +#loc252 = loc(callsite(#loc217 at #loc215)) +#loc253 = loc(callsite(#loc171 at #loc215)) +#loc254 = loc(callsite(#loc172 at #loc215)) +#loc255 = loc(callsite(#loc173 at #loc215)) +#loc256 = loc(callsite(#loc174 at #loc215)) +#loc257 = loc(callsite(#loc175 at #loc215)) +#loc258 = loc(callsite(#loc176 at #loc215)) +#loc259 = loc(callsite(#loc177 at #loc215)) +#loc260 = loc(callsite(#loc178 at #loc215)) +#loc261 = loc(callsite(#loc179 at #loc215)) +#loc262 = loc(callsite(#loc180 at #loc215)) +#loc263 = loc(callsite(#loc181 at #loc215)) +#loc264 = loc(callsite(#loc143 at #loc219)) +#loc265 = loc(callsite(#loc145 at #loc219)) +#loc267 = loc(callsite(#loc147 at #loc219)) +#loc268 = loc(callsite(#loc148 at #loc219)) +#loc269 = loc(callsite(#loc149 at #loc219)) +#loc271 = loc(callsite(#loc151 at #loc219)) +#loc272 = loc(callsite(#loc152 at 
#loc219)) +#loc273 = loc(callsite(#loc153 at #loc219)) +#loc274 = loc(callsite(#loc154 at #loc219)) +#loc275 = loc(callsite(#loc216 at #loc219)) +#loc276 = loc(callsite(#loc217 at #loc219)) +#loc277 = loc(callsite(#loc172 at #loc219)) +#loc278 = loc(callsite(#loc173 at #loc219)) +#loc279 = loc(callsite(#loc174 at #loc219)) +#loc280 = loc(callsite(#loc175 at #loc219)) +#loc281 = loc(callsite(#loc176 at #loc219)) +#loc282 = loc(callsite(#loc177 at #loc219)) +#loc283 = loc(callsite(#loc178 at #loc219)) +#loc284 = loc(callsite(#loc180 at #loc219)) +#loc285 = loc(callsite(#loc181 at #loc219)) +#loc286 = loc(callsite(#loc155 at #loc219)) +#loc287 = loc(callsite(#loc157 at #loc219)) +#loc289 = loc(callsite(#loc160 at #loc219)) +#loc290 = loc(callsite(#loc161 at #loc219)) +#loc291 = loc(callsite(#loc163 at #loc219)) +#loc293 = loc(callsite(#loc165 at #loc219)) +#loc294 = loc(callsite(#loc166 at #loc219)) +#loc295 = loc(callsite(#loc167 at #loc219)) +#loc296 = loc(callsite(#loc168 at #loc219)) +#loc297 = loc(callsite(#loc171 at #loc219)) +#loc298 = loc(callsite(#loc179 at #loc219)) +#loc299 = loc(callsite(#loc28 at #loc221)) +#loc300 = loc(callsite(#loc28 at #loc224)) +#loc301 = loc(callsite(#loc26 at #loc229)) +#loc303 = loc(callsite(#loc26 at #loc233)) +#loc305 = loc(callsite(#loc158 at #loc241)) +#loc306 = loc(callsite(#loc26 at #loc241)) +#loc308 = loc(callsite(#loc158 at #loc246)) +#loc309 = loc(callsite(#loc26 at #loc246)) +#loc311 = loc(callsite(#loc26 at #loc266)) +#loc313 = loc(callsite(#loc26 at #loc270)) +#loc315 = loc(callsite(#loc26 at #loc288)) +#loc317 = loc(callsite(#loc26 at #loc292)) +#loc319 = loc(callsite(#loc28 at #loc301)) +#loc320 = loc(callsite(#loc28 at #loc303)) +#loc321 = loc(callsite(#loc28 at #loc306)) +#loc322 = loc(callsite(#loc28 at #loc309)) +#loc323 = loc(callsite(#loc28 at #loc311)) +#loc324 = loc(callsite(#loc28 at #loc313)) +#loc325 = loc(callsite(#loc28 at #loc315)) +#loc326 = loc(callsite(#loc28 at #loc317)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..a9464cd513fb91705f87b36d261c5b7c163cf10b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ptx @@ -0,0 +1,995 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 // -- Begin function triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +) +.noreturn; +.global .align 1 .b8 assertFunc_0[8] = {117, 110, 107, 110, 111, 119, 110}; +.global .align 1 .b8 assertFile_0[114] = {47, 119, 111, 114, 107, 115, 112, 97, 99, 101, 47, 104, 97, 110, 114, 117, 105, 47, 83, 112, 101, 99, 70, 111, 114, 103, 101, 45, 101, 120, 116, 47, 99, 97, 99, 104, 101, 47, 99, 111, 109, 112, 105, 108, 101, 100, 95, 107, 101, 114, 110, 101, 108, 115, 47, 51, 101, 47, 99, 51, 101, 99, 118, 112, 55, 113, 110, 54, 120, 103, 52, 102, 55, 109, 51, 55, 99, 100, 120, 114, 111, 97, 100, 107, 105, 114, 53, 99, 104, 54, 114, 115, 97, 54, 110, 55, 105, 104, 99, 121, 107, 51, 120, 112, 115, 53, 102, 102, 111, 54, 46, 112, 121}; +.global .align 1 .b8 assertMessage_0[40] = {105, 110, 100, 101, 120, 32, 111, 117, 116, 32, 111, 102, 32, 98, 111, 117, 110, 100, 115, 58, 32, 48, 32, 60, 61, 32, 116, 109, 112, 54, 32, 60, 32, 49, 53, 49, 57, 51, 54}; +.extern .shared .align 16 .b8 
global_smem[]; + // @triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 +.visible .entry triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0( + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_2, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_3, + .param .u64 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_4, + .param .u64 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_7 +) +.reqntid 512 +{ + .reg .pred %p<212>; + .reg .b16 %rs<7>; + .reg .b32 %r<104>; + .reg .b64 %rd<119>; + .loc 1 18 0 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:18:0 +$L__func_begin0: + .loc 1 18 0 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:18:0 + +// %bb.0: + ld.param.b64 %rd23, [triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_4]; + ld.param.b64 %rd22, [triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_3]; + ld.param.b64 %rd21, [triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_2]; + ld.param.b64 %rd20, [triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_0]; + ld.param.b64 %rd29, [triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_param_1]; +$L__tmp0: + .loc 1 22 28 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:22:28 + mov.u32 %r7, %ctaid.x; + .loc 1 22 34 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:22:34 + cvt.u64.u32 %rd1, %r7; + .loc 1 25 37 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:25:37 + mov.u32 %r1, %tid.x; + shl.b32 %r8, %r1, 2; + and.b32 %r9, 
%r8, 2044; + .loc 1 30 40 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:30:40 + cvt.u64.u32 %rd2, %r9; + and.b32 %r10, %r1, 511; + mul.wide.u32 %rd30, %r10, 8; + mad.wide.u32 %rd31, %r7, 303872, %rd30; + add.s64 %rd112, %rd29, %rd31; + mov.b32 %r102, 0fFF800000; + mov.b64 %rd116, {%r102, %r102}; + mov.b64 %rd114, 9223372036854775807; + mov.b64 %rd113, -2048; + setp.gt.s64 %p2, %rd23, %rd1; + mov.b64 %rd115, %rd114; + mov.b32 %r103, %r102; + mov.b64 %rd117, %rd114; + mov.b64 %rd118, %rd114; +$L__BB0_1: // =>This Inner Loop Header: Depth=1 + .loc 1 36 53 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:36:53 + // begin inline asm + mov.u64 %rd32, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd32, 1.0; + // end inline asm +$L__tmp1: + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + mov.b64 {%r15, %r16}, %rd116; + setp.nan.f32 %p3, %r15, %r15; + setp.nan.f32 %p4, %r16, %r16; + setp.nan.f32 %p5, %r102, %r102; + setp.nan.f32 %p6, %r103, %r103; +$L__tmp2: + .loc 1 31 31 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:31:31 + add.s64 %rd35, %rd2, %rd113; + add.s64 %rd36, %rd35, 2048; + add.s64 %rd37, %rd35, 2049; + add.s64 %rd38, %rd35, 2050; + .loc 1 32 29 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:32:29 + add.s64 %rd39, %rd35, 2051; + setp.lt.u64 %p7, %rd36, 151936; + .loc 1 36 63 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:36:63 + and.pred %p1, %p2, %p7; + mov.b32 %r13, 0; + .loc 1 36 53 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:36:53 + // begin inline asm + mov.u32 %r11, %r13; + mov.u32 %r12, %r13; + @%p1 ld.global.L1::evict_first.L2::cache_hint.v2.b32 { %r11, %r12 }, [ %rd112 + 0 ], %rd32; + // end inline asm + mov.b32 {%rs1, %rs2}, %r12; + .loc 1 36 115 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:36:115 + mov.b32 {%rs3, %rs4}, %r11; + cvt.f32.bf16 %r17, %rs3; + cvt.f32.bf16 %r18, %rs4; + cvt.f32.bf16 
%r19, %rs1; + cvt.f32.bf16 %r20, %rs2; +$L__tmp3: + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + setp.gt.f32 %p8, %r16, %r18; + setp.gt.f32 %p9, %r15, %r17; + setp.gt.f32 %p10, %r102, %r19; + setp.gt.f32 %p11, %r103, %r20; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + setp.eq.f32 %p12, %r15, %r17; + setp.eq.f32 %p13, %r16, %r18; + setp.eq.f32 %p14, %r102, %r19; + setp.eq.f32 %p15, %r103, %r20; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + setp.nan.bf16x2 %p16|%p17, %r11, %r13; + setp.num.bf16x2 %p18|%p19, %r11, %r13; + setp.nan.bf16 %p20, %rs1, %rs1; + setp.num.bf16 %p21, %rs1, %rs1; + setp.nan.bf16 %p22, %rs2, %rs2; + setp.num.bf16 %p23, %rs2, %rs2; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + and.pred %p24, %p4, %p19; + and.pred %p25, %p3, %p18; + and.pred %p26, %p5, %p21; + and.pred %p27, %p6, %p23; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + or.pred %p28, %p9, %p25; + or.pred %p29, %p8, %p24; + or.pred %p30, %p10, %p26; + or.pred %p31, %p11, %p27; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + and.pred %p32, %p3, %p16; + and.pred %p33, %p4, %p17; + and.pred %p34, %p5, %p20; + and.pred %p35, %p6, %p22; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + or.pred %p36, %p13, %p33; + or.pred %p37, %p12, %p32; + or.pred %p38, %p14, %p34; + or.pred %p39, %p15, %p35; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + setp.lt.s64 %p40, %rd117, %rd36; + setp.le.s64 %p41, %rd118, %rd36; + setp.lt.s64 %p42, %rd114, %rd38; + setp.lt.s64 %p43, %rd115, %rd39; 
+ .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + and.pred %p44, %p40, %p37; + and.pred %p45, %p41, %p36; + and.pred %p46, %p42, %p38; + and.pred %p47, %p43, %p39; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + or.pred %p48, %p29, %p45; + or.pred %p49, %p28, %p44; + or.pred %p50, %p30, %p46; + or.pred %p51, %p31, %p47; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + selp.f32 %r21, %r15, %r17, %p49; + selp.f32 %r22, %r16, %r18, %p48; + selp.f32 %r23, %r102, %r19, %p50; + selp.f32 %r24, %r103, %r20, %p51; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:39:38 ] + selp.b64 %rd40, %rd117, %rd36, %p49; + selp.b64 %rd41, %rd118, %rd37, %p48; + selp.b64 %rd42, %rd114, %rd38, %p50; + selp.b64 %rd43, %rd115, %rd39, %p51; +$L__tmp4: + .loc 1 41 54 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:41:54 + selp.f32 %r25, %r22, %r16, %p1; + selp.f32 %r26, %r21, %r15, %p1; + mov.b64 %rd116, {%r26, %r25}; + selp.f32 %r102, %r23, %r102, %p1; + selp.f32 %r103, %r24, %r103, %p1; + .loc 1 42 66 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:42:66 + selp.b64 %rd118, %rd41, %rd118, %p1; + selp.b64 %rd117, %rd40, %rd117, %p1; + selp.b64 %rd114, %rd42, %rd114, %p1; + selp.b64 %rd115, %rd43, %rd115, %p1; + .loc 1 30 40 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:30:40 + add.s64 %rd113, %rd113, 2048; + add.s64 %rd112, %rd112, 4096; + setp.lt.u64 %p52, %rd113, 149888; + @%p52 bra $L__BB0_1; +// %bb.2: + .loc 1 24 21 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:24:21 + setp.le.s64 %p60, %rd23, %rd1; + .loc 1 25 37 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:25:37 + and.b32 %r36, %r1, 31; +$L__tmp5: + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + mov.b64 {%r37, %r38}, %rd116; + setp.gt.f32 %p61, %r37, %r38; + setp.eq.f32 %p62, %r38, %r37; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p63, %r37, %r37; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.num.f32 %p64, %r38, %r38; + setp.nan.f32 %p65, %r38, %r38; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p66, %p63, %p65; + and.pred %p67, %p63, %p64; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p68, %p61, %p67; + or.pred %p69, %p62, %p66; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p70, %rd117, %rd118; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p71, %p70, %p69; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p72, %p68, %p71; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r39, %r37, %r38, %p72; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd51, %rd117, %rd118, %p72; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p73, %r39, %r102; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p74, %r39, %r102; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p75, %r39, %r39; + .loc 2 148 29 // 
triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p76, %r102, %r102; + setp.num.f32 %p77, %r102, %r102; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p78, %p75, %p77; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p79, %p73, %p78; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p80, %p76, %p75; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p81, %p74, %p80; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p82, %rd51, %rd114; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p83, %p82, %p81; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p84, %p79, %p83; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r40, %r39, %r102, %p84; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd52, %rd51, %rd114, %p84; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p85, %r40, %r103; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p86, %r40, %r103; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p87, %r40, %r40; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 
%p88, %r103, %r103; + setp.num.f32 %p89, %r103, %r103; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p90, %p87, %p89; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p91, %p85, %p90; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p92, %p88, %p87; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p93, %p86, %p92; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p94, %rd52, %rd115; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p95, %p94, %p93; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p96, %p91, %p95; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r41, %r40, %r103, %p96; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd53, %rd52, %rd115, %p96; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r42, %r41, 16, 31, -1; + mov.b64 {_, %r43}, %rd53; + cvt.u32.u64 %r44, %rd53; + shfl.sync.bfly.b32 %r45, %r44, 16, 31, -1; + shfl.sync.bfly.b32 %r46, %r43, 16, 31, -1; + cvt.u64.u32 %rd54, %r45; + cvt.u64.u32 %rd55, %r46; + shl.b64 %rd56, %rd55, 32; + or.b64 %rd57, %rd54, %rd56; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p97, %r41, %r42; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 
] + setp.eq.f32 %p98, %r41, %r42; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p99, %r41, %r41; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p100, %r42, %r42; + setp.num.f32 %p101, %r42, %r42; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p102, %p99, %p101; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p103, %p97, %p102; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p104, %p99, %p100; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p105, %p98, %p104; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p106, %rd53, %rd57; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p107, %p105, %p106; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p108, %p103, %p107; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r47, %r41, %r42, %p108; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd58, %rd53, %rd57, %p108; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r48, %r47, 8, 31, -1; + mov.b64 {_, %r49}, %rd58; + cvt.u32.u64 %r50, %rd58; + shfl.sync.bfly.b32 %r51, %r50, 8, 31, -1; + shfl.sync.bfly.b32 %r52, %r49, 8, 31, -1; + cvt.u64.u32 %rd59, %r51; + cvt.u64.u32 %rd60, 
%r52; + shl.b64 %rd61, %rd60, 32; + or.b64 %rd62, %rd59, %rd61; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p109, %r47, %r48; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p110, %r47, %r48; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p111, %r47, %r47; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p112, %r48, %r48; + setp.num.f32 %p113, %r48, %r48; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p114, %p111, %p113; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p115, %p109, %p114; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p116, %p112, %p111; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p117, %p110, %p116; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p118, %rd58, %rd62; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p119, %p118, %p117; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p120, %p115, %p119; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r53, %r47, %r48, %p120; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd63, %rd58, %rd62, %p120; + .loc 2 165 42 // 
triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r54, %r53, 4, 31, -1; + mov.b64 {_, %r55}, %rd63; + cvt.u32.u64 %r56, %rd63; + shfl.sync.bfly.b32 %r57, %r56, 4, 31, -1; + shfl.sync.bfly.b32 %r58, %r55, 4, 31, -1; + cvt.u64.u32 %rd64, %r57; + cvt.u64.u32 %rd65, %r58; + shl.b64 %rd66, %rd65, 32; + or.b64 %rd67, %rd64, %rd66; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p121, %r53, %r54; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p122, %r53, %r54; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p123, %r53, %r53; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p124, %r54, %r54; + setp.num.f32 %p125, %r54, %r54; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p126, %p123, %p125; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p127, %p121, %p126; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p128, %p124, %p123; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p129, %p122, %p128; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p130, %rd63, %rd67; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p131, %p130, %p129; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p132, %p127, 
%p131; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r59, %r53, %r54, %p132; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd68, %rd63, %rd67, %p132; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r60, %r59, 2, 31, -1; + mov.b64 {_, %r61}, %rd68; + cvt.u32.u64 %r62, %rd68; + shfl.sync.bfly.b32 %r63, %r62, 2, 31, -1; + shfl.sync.bfly.b32 %r64, %r61, 2, 31, -1; + cvt.u64.u32 %rd69, %r63; + cvt.u64.u32 %rd70, %r64; + shl.b64 %rd71, %rd70, 32; + or.b64 %rd72, %rd69, %rd71; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p133, %r59, %r60; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p134, %r59, %r60; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p135, %r59, %r59; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p136, %r60, %r60; + setp.num.f32 %p137, %r60, %r60; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p138, %p135, %p137; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p139, %p133, %p138; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p140, %p136, %p135; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p141, %p134, %p140; + .loc 2 154 31 // triton_helpers.py:154:31 @[ 
c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p142, %rd68, %rd72; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p143, %p142, %p141; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p144, %p139, %p143; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r65, %r59, %r60, %p144; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd73, %rd68, %rd72, %p144; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r66, %r65, 1, 31, -1; + mov.b64 {_, %r67}, %rd73; + cvt.u32.u64 %r68, %rd73; + shfl.sync.bfly.b32 %r69, %r68, 1, 31, -1; + shfl.sync.bfly.b32 %r70, %r67, 1, 31, -1; + cvt.u64.u32 %rd74, %r69; + cvt.u64.u32 %rd75, %r70; + shl.b64 %rd76, %rd75, 32; + or.b64 %rd77, %rd74, %rd76; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p145, %r65, %r66; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p146, %r65, %r66; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p147, %r65, %r65; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p148, %r66, %r66; + setp.num.f32 %p149, %r66, %r66; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p150, %p147, %p149; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p151, %p145, %p150; + .loc 2 
151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p152, %p148, %p147; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p153, %p146, %p152; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p154, %rd73, %rd77; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p155, %p154, %p153; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p156, %p151, %p155; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd44, %rd73, %rd77, %p156; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + bfe.u32 %r71, %r1, 5, 4; + setp.eq.b32 %p53, %r36, 0; + mov.b32 %r72, global_smem; + add.s32 %r73, %r72, 128; + shl.b32 %r74, %r71, 2; + add.s32 %r27, %r73, %r74; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b32 %r28, %r65, %r66, %p156; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + // begin inline asm + @%p53 st.shared.b32 [ %r27 + 0 ], %r28; + // end inline asm + shl.b32 %r75, %r71, 3; + add.s32 %r29, %r72, %r75; + // begin inline asm + @%p53 st.shared.b64 [ %r29 + 0 ], %rd44; + // end inline asm + bar.sync 0; + setp.lt.u32 %p55, %r1, 16; + add.s32 %r31, %r73, %r8; + // begin inline asm + @%p55 ld.shared.b32 %r30, [ %r31 + 0 ]; + // end inline asm + shl.b32 %r77, %r1, 3; + add.s32 %r32, %r72, %r77; + // begin inline asm + @%p55 ld.shared.b64 %rd45, [ %r32 + 0 ]; + // end inline asm + shfl.sync.bfly.b32 %r78, %r30, 8, 31, -1; + mov.b64 {_, %r79}, %rd45; + cvt.u32.u64 %r80, 
%rd45; + shfl.sync.bfly.b32 %r81, %r80, 8, 31, -1; + shfl.sync.bfly.b32 %r82, %r79, 8, 31, -1; + cvt.u64.u32 %rd78, %r81; + cvt.u64.u32 %rd79, %r82; + shl.b64 %rd80, %rd79, 32; + or.b64 %rd81, %rd78, %rd80; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p157, %r30, %r78; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p158, %r30, %r78; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p159, %r30, %r30; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p160, %r78, %r78; + setp.num.f32 %p161, %r78, %r78; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p162, %p159, %p161; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p163, %p157, %p162; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p164, %p159, %p160; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p165, %p158, %p164; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p166, %rd45, %rd81; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p167, %p165, %p166; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p168, %p163, %p167; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r83, %r30, %r78, %p168; + .loc 2 155 69 // 
triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd82, %rd45, %rd81, %p168; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r84, %r83, 4, 31, -1; + mov.b64 {_, %r85}, %rd82; + cvt.u32.u64 %r86, %rd82; + shfl.sync.bfly.b32 %r87, %r86, 4, 31, -1; + shfl.sync.bfly.b32 %r88, %r85, 4, 31, -1; + cvt.u64.u32 %rd83, %r87; + cvt.u64.u32 %rd84, %r88; + shl.b64 %rd85, %rd84, 32; + or.b64 %rd86, %rd83, %rd85; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p169, %r83, %r84; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p170, %r83, %r84; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p171, %r83, %r83; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p172, %r84, %r84; + setp.num.f32 %p173, %r84, %r84; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p174, %p171, %p173; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p175, %p169, %p174; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p176, %p172, %p171; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p177, %p170, %p176; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p178, %rd82, %rd86; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p179, 
%p178, %p177; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p180, %p175, %p179; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r89, %r83, %r84, %p180; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd87, %rd82, %rd86, %p180; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r90, %r89, 2, 31, -1; + mov.b64 {_, %r91}, %rd87; + cvt.u32.u64 %r92, %rd87; + shfl.sync.bfly.b32 %r93, %r92, 2, 31, -1; + shfl.sync.bfly.b32 %r94, %r91, 2, 31, -1; + cvt.u64.u32 %rd88, %r93; + cvt.u64.u32 %rd89, %r94; + shl.b64 %rd90, %rd89, 32; + or.b64 %rd91, %rd88, %rd90; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p181, %r89, %r90; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p182, %r89, %r90; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p183, %r89, %r89; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p184, %r90, %r90; + setp.num.f32 %p185, %r90, %r90; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p186, %p183, %p185; + .loc 2 149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p187, %p181, %p186; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p188, %p184, %p183; + .loc 2 151 17 // triton_helpers.py:151:17 @[ 
c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p189, %p182, %p188; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p190, %rd87, %rd91; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p191, %p190, %p189; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p192, %p187, %p191; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.f32 %r95, %r89, %r90, %p192; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd92, %rd87, %rd91, %p192; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + shfl.sync.bfly.b32 %r96, %r95, 1, 31, -1; + mov.b64 {_, %r97}, %rd92; + cvt.u32.u64 %r98, %rd92; + shfl.sync.bfly.b32 %r99, %r98, 1, 31, -1; + shfl.sync.bfly.b32 %r100, %r97, 1, 31, -1; + cvt.u64.u32 %rd93, %r99; + cvt.u64.u32 %rd94, %r100; + shl.b64 %rd95, %rd94, 32; + or.b64 %rd96, %rd93, %rd95; + .loc 2 144 21 // triton_helpers.py:144:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.gt.f32 %p193, %r95, %r96; + .loc 2 145 23 // triton_helpers.py:145:23 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.f32 %p194, %r95, %r96; + .loc 2 147 29 // triton_helpers.py:147:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p195, %r95, %r95; + .loc 2 148 29 // triton_helpers.py:148:29 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.nan.f32 %p196, %r96, %r96; + setp.num.f32 %p197, %r96, %r96; + .loc 2 149 27 // triton_helpers.py:149:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p198, %p195, %p197; + .loc 2 
149 16 // triton_helpers.py:149:16 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p199, %p193, %p198; + .loc 2 151 27 // triton_helpers.py:151:27 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p200, %p196, %p195; + .loc 2 151 17 // triton_helpers.py:151:17 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p201, %p194, %p200; + .loc 2 154 31 // triton_helpers.py:154:31 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.lt.s64 %p202, %rd92, %rd96; + .loc 2 154 21 // triton_helpers.py:154:21 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + and.pred %p203, %p202, %p201; + .loc 2 154 12 // triton_helpers.py:154:12 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + or.pred %p204, %p199, %p203; + .loc 2 155 69 // triton_helpers.py:155:69 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b64 %rd46, %rd92, %rd96, %p204; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + setp.eq.b32 %p57, %r1, 0; + .loc 2 155 35 // triton_helpers.py:155:35 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + selp.b32 %r34, %r95, %r96, %p204; + .loc 2 165 42 // triton_helpers.py:165:42 @[ c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:43:75 ] + // begin inline asm + @%p57 st.shared.b32 [ %r31 + 0 ], %r34; + // end inline asm + // begin inline asm + @%p57 st.shared.b64 [ %r32 + 0 ], %rd46; + // end inline asm + bar.sync 0; + ld.shared.b64 %rd97, [global_smem]; +$L__tmp6: + .loc 1 45 31 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:45:31 + shl.b64 %rd98, %rd1, 3; + add.s64 %rd49, %rd22, %rd98; + .loc 1 45 36 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:45:36 + // begin inline asm + mov.u64 %rd47, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd47, 1.0; + // end inline asm + // begin inline asm + mov.u64 
%rd48, 0x0; + @%p2 ld.global.L1::evict_last.L2::cache_hint.b64 { %rd48 }, [ %rd49 + 0 ], %rd47; + // end inline asm + .loc 1 47 18 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:47:18 + add.s64 %rd99, %rd97, 151936; + .loc 1 48 18 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:48:18 + setp.lt.s64 %p205, %rd97, 0; + .loc 1 49 32 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:49:32 + selp.b64 %rd19, %rd99, %rd97, %p205; + .loc 1 50 37 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:50:37 + setp.lt.u64 %p206, %rd19, 151936; + .loc 1 50 65 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:50:65 + or.pred %p207, %p60, %p206; + @%p207 bra $L__BB0_4; + bra.uni $L__BB0_3; +$L__BB0_4: + bar.sync 0; + .loc 1 51 30 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:51:30 + add.s64 %rd101, %rd21, %rd19; + .loc 1 51 37 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:51:37 + // begin inline asm + mov.u64 %rd102, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd102, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs5, 0x0; + @%p2 ld.global.L1::evict_last.L2::cache_hint.b8 { %rs5 }, [ %rd101 + 0 ], %rd102; + // end inline asm + and.b16 %rs6, %rs5, 255; + setp.eq.b16 %p210, %rs6, 0; + .loc 1 54 20 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:54:20 + selp.b64 %rd103, 0, %rd48, %p210; + .loc 1 55 4 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:55:4 + bar.sync 0; + .loc 1 56 28 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:56:28 + add.s64 %rd104, %rd20, %rd98; + .loc 1 56 40 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:56:40 + setp.eq.b32 %p211, %r10, 0; + and.pred %p209, %p211, %p2; + // begin inline asm + @%p209 st.global.b64 [ %rd104 + 0 ], { %rd103 }; + // end inline asm + .loc 1 56 4 // c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:56:4 + ret; +$L__BB0_3: + .loc 1 50 65 // 
c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py:50:65 + { // callseq 0, 0 + .param .b64 param0; + .param .b64 param1; + .param .b32 param2; + .param .b64 param3; + .param .b64 param4; + mov.b64 %rd106, assertFunc_0; + cvta.global.u64 %rd107, %rd106; + st.param.b64 [param3], %rd107; + mov.b64 %rd108, assertFile_0; + cvta.global.u64 %rd109, %rd108; + st.param.b64 [param1], %rd109; + mov.b64 %rd110, assertMessage_0; + cvta.global.u64 %rd111, %rd110; + st.param.b64 [param0], %rd111; + st.param.b64 [param4], 1; + st.param.b32 [param2], 50; + call.uni __assertfail, (param0, param1, param2, param3, param4); + } // callseq 0 + trap; +$L__tmp7: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 
18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 263 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x100 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 51 +.b8 101 +.b8 99 +.b8 118 +.b8 112 +.b8 55 +.b8 113 +.b8 110 +.b8 54 +.b8 120 +.b8 103 +.b8 52 +.b8 102 +.b8 55 +.b8 109 +.b8 51 +.b8 55 +.b8 99 +.b8 100 +.b8 120 +.b8 114 +.b8 111 +.b8 97 +.b8 100 +.b8 107 +.b8 105 +.b8 114 +.b8 53 +.b8 99 +.b8 104 +.b8 54 +.b8 114 +.b8 115 +.b8 97 +.b8 54 +.b8 110 +.b8 55 +.b8 105 +.b8 104 +.b8 99 +.b8 121 +.b8 107 +.b8 51 +.b8 120 +.b8 112 +.b8 115 +.b8 53 +.b8 102 +.b8 102 +.b8 111 +.b8 54 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 51 +.b8 101 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x39 DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 95 +.b8 116 +.b8 111 +.b8 95 +.b8 99 +.b8 111 +.b8 112 +.b8 121 +.b8 95 +.b8 97 +.b8 114 +.b8 103 +.b8 109 +.b8 97 +.b8 120 +.b8 95 
+.b8 105 +.b8 110 +.b8 100 +.b8 101 +.b8 120 +.b8 95 +.b8 109 +.b8 117 +.b8 108 +.b8 95 +.b8 117 +.b8 110 +.b8 115 +.b8 113 +.b8 117 +.b8 101 +.b8 101 +.b8 122 +.b8 101 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xc4:0x46 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xd9:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp4 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 39 // DW_AT_call_line +.b8 38 // DW_AT_call_column +.b8 4 // Abbrev [4] 0xf1:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp5 // DW_AT_low_pc +.b64 $L__tmp6 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 43 // DW_AT_call_line +.b8 75 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.source new file mode 100644 index 0000000000000000000000000000000000000000..8cb87c3bfa82c19c76e051c3cfdd54b3fb67a720 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.source @@ -0,0 +1,378 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":18:0) +#loc53 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":143:0) +#loc65 = loc(unknown) +#loc73 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":86:0) +#loc77 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":63:0) +#loc86 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":164:0) +#loc90 = loc("in_out_ptr0"(#loc)) +#loc91 = loc("in_ptr0"(#loc)) +#loc92 = loc("in_ptr1"(#loc)) +#loc93 = loc("in_ptr2"(#loc)) +#loc94 = loc("xnumel"(#loc)) +#loc95 = loc("r0_numel"(#loc)) +#loc135 = loc("a_value"(#loc53)) +#loc136 = loc("a_index"(#loc53)) +#loc137 = loc("b_value"(#loc53)) +#loc138 = loc("b_index"(#loc53)) +#loc151 = loc("x"(#loc73)) +#loc152 = loc("x"(#loc77)) +#loc153 = loc("value"(#loc86)) +#loc154 = loc("index"(#loc86)) +module { + tt.func public @triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(%in_out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_out_ptr0"(#loc)), %in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %xnumel: i64 loc("xnumel"(#loc)), %r0_numel: i64 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %r0_numel_0 = arith.constant 151936 : i32 loc(#loc96) + %xoffset = tt.get_program_id x : i32 loc(#loc97) + %xoffset_1 = arith.extsi %xoffset : i32 to i64 loc(#loc98) + %xoffset_2 = arith.constant 1 : i32 loc(#loc99) + %xoffset_3 = arith.constant 1 : i64 loc(#loc99) + %xoffset_4 = arith.muli %xoffset_1, %xoffset_3 : i64 loc(#loc99) + %xindex = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> loc(#loc100) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc101) + %xindex_6 = arith.extsi %xindex_5 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc102) + %xindex_7 = tt.splat %xoffset_4 : i64 -> tensor<1x1xi64> loc(#loc103) + %xindex_8 = arith.addi %xindex_7, %xindex_6 : tensor<1x1xi64> loc(#loc103) + %xmask = tt.splat %xnumel : i64 -> tensor<1x1xi64> loc(#loc104) + 
%xmask_9 = arith.cmpi slt, %xindex_8, %xmask : tensor<1x1xi64> loc(#loc104) + %r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> loc(#loc105) + %r0_base_10 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32> -> tensor<1x2048xi32> loc(#loc106) + %r0_base_11 = arith.extsi %r0_base_10 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc107) + %_tmp2 = arith.constant 0xFF800000 : f32 loc(#loc108) + %_tmp2_12 = arith.constant dense<0xFF800000> : tensor<1x2048xf32> loc(#loc108) + %_tmp2_index = arith.constant 9223372036854775807 : i64 loc(#loc109) + %_tmp2_index_13 = arith.constant dense<9223372036854775807> : tensor<1x2048xi64> loc(#loc109) + %c0_i32 = arith.constant 0 : i32 loc(#loc15) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc15) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc15) + %1 = arith.bitcast %r0_numel_0 : i32 to i32 loc(#loc15) + %2 = arith.bitcast %c2048_i32 : i32 to i32 loc(#loc15) + %3 = ub.poison : i32 loc(#loc15) + %_tmp2_index_14:2 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp2_29 = %_tmp2_12, %_tmp2_index_30 = %_tmp2_index_13) -> (tensor<1x2048xf32>, tensor<1x2048xi64>) : i32 { + %r0_index = arith.extsi %r0_offset : i32 to i64 loc(#loc111) + %r0_index_31 = tt.splat %r0_index : i64 -> tensor<1x2048xi64> loc(#loc111) + %r0_index_32 = arith.addi %r0_index_31, %r0_base_11 : tensor<1x2048xi64> loc(#loc111) + %r0_mask = arith.extsi %r0_numel_0 : i32 to i64 loc(#loc112) + %r0_mask_33 = tt.splat %r0_mask : i64 -> tensor<1x2048xi64> loc(#loc112) + %r0_mask_34 = arith.cmpi slt, %r0_index_32, %r0_mask_33 : tensor<1x2048xi64> loc(#loc112) + %tmp0 = arith.constant 151936 : i32 loc(#loc113) + %tmp0_35 = arith.constant 151936 : i64 loc(#loc113) + %tmp0_36 = arith.constant dense<151936> : tensor<1x1xi64> loc(#loc113) + %tmp0_37 = arith.muli %tmp0_36, %xindex_8 : tensor<1x1xi64> loc(#loc113) + %tmp0_38 = tt.broadcast %tmp0_37 : tensor<1x1xi64> -> tensor<1x2048xi64> loc(#loc114) + %tmp0_39 = arith.addi 
%r0_index_32, %tmp0_38 : tensor<1x2048xi64> loc(#loc114) + %tmp0_40 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc115) + %tmp0_41 = tt.addptr %tmp0_40, %tmp0_39 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc115) + %tmp0_42 = tt.broadcast %xmask_9 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc116) + %tmp0_43 = arith.andi %r0_mask_34, %tmp0_42 : tensor<1x2048xi1> loc(#loc116) + %tmp0_44 = arith.constant 0.000000e+00 : f32 loc(#loc117) + %tmp0_45 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32> loc(#loc117) + %tmp0_46 = arith.truncf %tmp0_45 : tensor<1x2048xf32> to tensor<1x2048xbf16> loc(#loc117) + %tmp0_47 = tt.load %tmp0_41, %tmp0_43, %tmp0_46 evictionPolicy = evict_first : tensor<1x2048x!tt.ptr> loc(#loc117) + %tmp0_48 = arith.extf %tmp0_47 : tensor<1x2048xbf16> to tensor<1x2048xf32> loc(#loc118) + %16:2 = tt.call @torch._inductor.runtime.triton_helpers.maximum_with_index__fp32S1_2048S_i64S1_2048S_fp32S1_2048S_i64S1_2048S__(%_tmp2_29, %_tmp2_index_30, %tmp0_48, %r0_index_32) : (tensor<1x2048xf32>, tensor<1x2048xi64>, tensor<1x2048xf32>, tensor<1x2048xi64>) -> (tensor<1x2048xf32>, tensor<1x2048xi64>) loc(#loc24) + %_tmp2_49 = tt.broadcast %xmask_9 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc119) + %_tmp2_50 = arith.andi %r0_mask_34, %_tmp2_49 : tensor<1x2048xi1> loc(#loc119) + %_tmp2_51 = arith.select %_tmp2_50, %16#0, %_tmp2_29 : tensor<1x2048xi1>, tensor<1x2048xf32> loc(#loc120) + %_tmp2_index_52 = tt.broadcast %xmask_9 : tensor<1x1xi1> -> tensor<1x2048xi1> loc(#loc121) + %_tmp2_index_53 = arith.andi %r0_mask_34, %_tmp2_index_52 : tensor<1x2048xi1> loc(#loc121) + %_tmp2_index_54 = arith.select %_tmp2_index_53, %16#1, %_tmp2_index_30 : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc122) + scf.yield %_tmp2_51, %_tmp2_index_54 : tensor<1x2048xf32>, tensor<1x2048xi64> loc(#loc29) + } loc(#loc155) + %4:2 = tt.call @"torch._inductor.runtime.triton_helpers.max_with_index__fp32S1_2048S_i64S1_2048S__(2,)cconstexpr_1_"(%_tmp2_index_14#0, 
%_tmp2_index_14#1) : (tensor<1x2048xf32>, tensor<1x2048xi64>) -> (tensor<1xf32>, tensor<1xi64>) loc(#loc30) + %tmp2 = tt.expand_dims %4#1 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc123) + %tmp11 = tt.splat %in_ptr2 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc124) + %tmp11_15 = tt.addptr %tmp11, %xindex_8 : tensor<1x1x!tt.ptr>, tensor<1x1xi64> loc(#loc124) + %tmp11_16 = tt.load %tmp11_15, %xmask_9 evictionPolicy = evict_last : tensor<1x1x!tt.ptr> loc(#loc125) + %tmp3 = arith.constant 151936 : i32 loc(#loc126) + %tmp3_17 = arith.constant dense<151936> : tensor<1x1xi32> loc(#loc126) + %tmp4 = arith.extsi %tmp3_17 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc127) + %tmp4_18 = arith.addi %tmp2, %tmp4 : tensor<1x1xi64> loc(#loc127) + %tmp5 = arith.constant 0 : i32 loc(#loc128) + %tmp5_19 = arith.extsi %tmp5 : i32 to i64 loc(#loc128) + %tmp5_20 = tt.splat %tmp5_19 : i64 -> tensor<1x1xi64> loc(#loc128) + %tmp5_21 = arith.cmpi slt, %tmp2, %tmp5_20 : tensor<1x1xi64> loc(#loc128) + %tmp6 = arith.select %tmp5_21, %tmp4_18, %tmp2 : tensor<1x1xi1>, tensor<1x1xi64> loc(#loc129) + %c0_i32_22 = arith.constant 0 : i32 loc(#loc38) + %5 = arith.extsi %c0_i32_22 : i32 to i64 loc(#loc38) + %6 = tt.splat %5 : i64 -> tensor<1x1xi64> loc(#loc38) + %7 = arith.cmpi sle, %6, %tmp6 : tensor<1x1xi64> loc(#loc38) + %c151936_i32 = arith.constant 151936 : i32 loc(#loc39) + %8 = arith.extsi %c151936_i32 : i32 to i64 loc(#loc39) + %9 = tt.splat %8 : i64 -> tensor<1x1xi64> loc(#loc39) + %10 = arith.cmpi slt, %tmp6, %9 : tensor<1x1xi64> loc(#loc39) + %11 = arith.andi %7, %10 : tensor<1x1xi1> loc(#loc40) + %true = arith.constant true loc(#loc41) + %cst = arith.constant dense : tensor<1x1xi1> loc(#loc41) + %12 = arith.xori %xmask_9, %cst : tensor<1x1xi1> loc(#loc41) + %13 = arith.ori %11, %12 : tensor<1x1xi1> loc(#loc42) + tt.assert %13, "index out of bounds: 0 <= tmp6 < 151936" : tensor<1x1xi1> loc(#loc43) + %tmp8 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc130) + 
%tmp8_23 = tt.addptr %tmp8, %tmp6 : tensor<1x1x!tt.ptr>, tensor<1x1xi64> loc(#loc130) + %tmp8_24 = tt.bitcast %tmp8_23 : tensor<1x1x!tt.ptr> -> tensor<1x1x!tt.ptr> loc(#loc131) + %tmp8_25 = tt.load %tmp8_24, %xmask_9 evictionPolicy = evict_last : tensor<1x1x!tt.ptr> loc(#loc131) + %tmp8_26 = arith.constant 0 : i8 loc(#loc131) + %tmp8_27 = arith.constant dense<0> : tensor<1x1xi8> loc(#loc131) + %tmp8_28 = arith.cmpi ne, %tmp8_25, %tmp8_27 : tensor<1x1xi8> loc(#loc131) + %tmp9 = arith.extui %tmp8_28 : tensor<1x1xi1> to tensor<1x1xi32> loc(#loc132) + %tmp10 = arith.extsi %tmp9 : tensor<1x1xi32> to tensor<1x1xi64> loc(#loc133) + %tmp12 = arith.muli %tmp10, %tmp11_16 : tensor<1x1xi64> loc(#loc134) + gpu.barrier loc(#loc49) + %14 = tt.splat %in_out_ptr0 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc50) + %15 = tt.addptr %14, %xindex_8 : tensor<1x1x!tt.ptr>, tensor<1x1xi64> loc(#loc50) + tt.store %15, %tmp12, %xmask_9 : tensor<1x1x!tt.ptr> loc(#loc51) + tt.return loc(#loc52) + } loc(#loc) + tt.func private @torch._inductor.runtime.triton_helpers.maximum_with_index__fp32S1_2048S_i64S1_2048S_fp32S1_2048S_i64S1_2048S__(%a_value: tensor<1x2048xf32> loc("a_value"(#loc53)), %a_index: tensor<1x2048xi64> loc("a_index"(#loc53)), %b_value: tensor<1x2048xf32> loc("b_value"(#loc53)), %b_index: tensor<1x2048xi64> loc("b_index"(#loc53))) -> (tensor<1x2048xf32>, tensor<1x2048xi64>) attributes {noinline = false} { + %mask = arith.cmpf ogt, %a_value, %b_value : tensor<1x2048xf32> loc(#loc156) + %equal = arith.cmpf oeq, %a_value, %b_value : tensor<1x2048xf32> loc(#loc157) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__fp32S1_2048S__(%a_value) : (tensor<1x2048xf32>) -> i1 loc(#loc56) + %1:2 = scf.if %0 -> (tensor<1x2048xi1>, tensor<1x2048xi1>) { + %a_isnan = arith.cmpf une, %a_value, %a_value : tensor<1x2048xf32> loc(#loc141) + %b_isnan = arith.cmpf une, %b_value, %b_value : tensor<1x2048xf32> loc(#loc142) + %mask_3 = arith.constant true loc(#loc143) + %mask_4 = 
arith.constant dense : tensor<1x2048xi1> loc(#loc143) + %mask_5 = arith.xori %b_isnan, %mask_4 : tensor<1x2048xi1> loc(#loc143) + %mask_6 = arith.andi %a_isnan, %mask_5 : tensor<1x2048xi1> loc(#loc144) + %mask_7 = arith.ori %mask, %mask_6 : tensor<1x2048xi1> loc(#loc158) + %equal_8 = arith.andi %a_isnan, %b_isnan : tensor<1x2048xi1> loc(#loc146) + %equal_9 = arith.ori %equal, %equal_8 : tensor<1x2048xi1> loc(#loc159) + scf.yield %mask_7, %equal_9 : tensor<1x2048xi1>, tensor<1x2048xi1> loc(#loc159) + } else { + scf.yield %mask, %equal : tensor<1x2048xi1>, tensor<1x2048xi1> loc(#loc65) + } loc(#loc57) + %mask_0 = arith.cmpi slt, %a_index, %b_index : tensor<1x2048xi64> loc(#loc148) + %mask_1 = arith.andi %1#1, %mask_0 : tensor<1x2048xi1> loc(#loc149) + %mask_2 = arith.ori %1#0, %mask_1 : tensor<1x2048xi1> loc(#loc150) + %2 = arith.select %mask_2, %a_value, %b_value : tensor<1x2048xi1>, tensor<1x2048xf32> loc(#loc69) + %3 = arith.select %mask_2, %a_index, %b_index : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc70) + tt.return %2, %3 : tensor<1x2048xf32>, tensor<1x2048xi64> loc(#loc71) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1x2048xf32> loc(#loc72) + %5 = ub.poison : tensor<1x2048xi64> loc(#loc72) + tt.return %4, %5 : tensor<1x2048xf32>, tensor<1x2048xi64> loc(#loc72) + } loc(#loc53) + tt.func private @torch._inductor.runtime.triton_helpers.is_floating__fp32S1_2048S__(%x: tensor<1x2048xf32> loc("x"(#loc73))) -> i1 attributes {noinline = false} { + %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S1_2048S__(%x) : (tensor<1x2048xf32>) -> tensor<1x2048xf32> loc(#loc74) + %true = arith.constant true loc(#loc75) + tt.return %true : i1 loc(#loc75) + ^bb1: // no predecessors + %1 = ub.poison : i1 loc(#loc76) + tt.return %1 : i1 loc(#loc76) + } loc(#loc73) + tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S1_2048S__(%x: tensor<1x2048xf32> loc("x"(#loc77))) -> tensor<1x2048xf32> attributes {noinline = false} 
{ + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1> loc(#loc78) + %1 = arith.uitofp %0 : tensor<1xi1> to tensor<1xf32> loc(#loc79) + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<1xf32> -> tensor<1x1xf32> loc(#loc79) + %3 = tt.broadcast %2 : tensor<1x1xf32> -> tensor<1x2048xf32> loc(#loc79) + %4 = arith.addf %x, %3 : tensor<1x2048xf32> loc(#loc79) + tt.return %4 : tensor<1x2048xf32> loc(#loc80) + ^bb1: // no predecessors + %5 = ub.poison : tensor<1x2048xf32> loc(#loc81) + tt.return %5 : tensor<1x2048xf32> loc(#loc81) + } loc(#loc77) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} { + %false = arith.constant false loc(#loc83) + %cst = arith.constant dense : tensor<1xi1> loc(#loc83) + tt.return %cst : tensor<1xi1> loc(#loc84) + ^bb1: // no predecessors + %0 = ub.poison : tensor<1xi1> loc(#loc85) + tt.return %0 : tensor<1xi1> loc(#loc85) + } loc(#loc82) + tt.func private @"torch._inductor.runtime.triton_helpers.max_with_index__fp32S1_2048S_i64S1_2048S__(2,)cconstexpr_1_"(%value: tensor<1x2048xf32> loc("value"(#loc86)), %index: tensor<1x2048xi64> loc("index"(#loc86))) -> (tensor<1xf32>, tensor<1xi64>) attributes {noinline = false} { + %0:2 = "tt.reduce"(%value, %index) <{axis = 1 : i32}> ({ + ^bb0(%arg2: f32 loc(unknown), %arg3: i64 loc(unknown), %arg4: f32 loc(unknown), %arg5: i64 loc(unknown)): + %3:2 = tt.call @torch._inductor.runtime.triton_helpers.maximum_with_index__fp32_i64_fp32_i64__(%arg2, %arg3, %arg4, %arg5) : (f32, i64, f32, i64) -> (f32, i64) loc(#loc87) + tt.reduce.return %3#0, %3#1 : f32, i64 loc(#loc87) + }) : (tensor<1x2048xf32>, tensor<1x2048xi64>) -> (tensor<1xf32>, tensor<1xi64>) loc(#loc87) + tt.return %0#0, %0#1 : tensor<1xf32>, tensor<1xi64> loc(#loc88) + ^bb1: // no predecessors + %1 = ub.poison : tensor<1xf32> loc(#loc89) + %2 = ub.poison : tensor<1xi64> loc(#loc89) + tt.return %1, 
%2 : tensor<1xf32>, tensor<1xi64> loc(#loc89) + } loc(#loc86) + tt.func private @torch._inductor.runtime.triton_helpers.maximum_with_index__fp32_i64_fp32_i64__(%a_value: f32 loc("a_value"(#loc53)), %a_index: i64 loc("a_index"(#loc53)), %b_value: f32 loc("b_value"(#loc53)), %b_index: i64 loc("b_index"(#loc53))) -> (f32, i64) attributes {noinline = false} { + %mask = arith.cmpf ogt, %a_value, %b_value : f32 loc(#loc156) + %equal = arith.cmpf oeq, %a_value, %b_value : f32 loc(#loc157) + %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__fp32__(%a_value) : (f32) -> i1 loc(#loc56) + %1:2 = scf.if %0 -> (i1, i1) { + %a_isnan = arith.cmpf une, %a_value, %a_value : f32 loc(#loc141) + %b_isnan = arith.cmpf une, %b_value, %b_value : f32 loc(#loc142) + %mask_3 = arith.constant true loc(#loc143) + %mask_4 = arith.xori %b_isnan, %mask_3 : i1 loc(#loc143) + %mask_5 = arith.andi %a_isnan, %mask_4 : i1 loc(#loc144) + %mask_6 = arith.ori %mask, %mask_5 : i1 loc(#loc158) + %equal_7 = arith.andi %a_isnan, %b_isnan : i1 loc(#loc146) + %equal_8 = arith.ori %equal, %equal_7 : i1 loc(#loc159) + scf.yield %mask_6, %equal_8 : i1, i1 loc(#loc159) + } else { + scf.yield %mask, %equal : i1, i1 loc(#loc65) + } loc(#loc57) + %mask_0 = arith.cmpi slt, %a_index, %b_index : i64 loc(#loc148) + %mask_1 = arith.andi %1#1, %mask_0 : i1 loc(#loc149) + %mask_2 = arith.ori %1#0, %mask_1 : i1 loc(#loc150) + %2 = arith.select %mask_2, %a_value, %b_value : f32 loc(#loc69) + %3 = arith.select %mask_2, %a_index, %b_index : i64 loc(#loc70) + tt.return %2, %3 : f32, i64 loc(#loc71) + ^bb1: // no predecessors + %4 = ub.poison : f32 loc(#loc72) + %5 = ub.poison : i64 loc(#loc72) + tt.return %4, %5 : f32, i64 loc(#loc72) + } loc(#loc53) + tt.func private @torch._inductor.runtime.triton_helpers.is_floating__fp32__(%x: f32 loc("x"(#loc73))) -> i1 attributes {noinline = false} { + %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32__(%x) : (f32) -> tensor<1xf32> loc(#loc74) + 
%true = arith.constant true loc(#loc75) + tt.return %true : i1 loc(#loc75) + ^bb1: // no predecessors + %1 = ub.poison : i1 loc(#loc76) + tt.return %1 : i1 loc(#loc76) + } loc(#loc73) + tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32__(%x: f32 loc("x"(#loc77))) -> tensor<1xf32> attributes {noinline = false} { + %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1> loc(#loc78) + %1 = arith.uitofp %0 : tensor<1xi1> to tensor<1xf32> loc(#loc79) + %2 = tt.splat %x : f32 -> tensor<1xf32> loc(#loc79) + %3 = arith.addf %2, %1 : tensor<1xf32> loc(#loc79) + tt.return %3 : tensor<1xf32> loc(#loc80) + ^bb1: // no predecessors + %4 = ub.poison : tensor<1xf32> loc(#loc81) + tt.return %4 : tensor<1xf32> loc(#loc81) + } loc(#loc77) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":19:15) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:34) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:46) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":23:36) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":23:44) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":23:56) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":23:23) +#loc9 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":24:21) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:27) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:37) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:49) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":28:55) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":29:67) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":30:40) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":31:31) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":32:29) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:48) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:41) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:34) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:63) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:53) +#loc23 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:115) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":39:38) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":41:35) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":41:54) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:41) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:66) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:8) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":43:75) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":44:20) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:31) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:36) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":46:40) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":47:18) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":48:18) +#loc37 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":49:32) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:28) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:44) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:37) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:57) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:55) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:65) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:30) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:37) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":52:19) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":53:20) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":54:20) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":55:4) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:28) +#loc51 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:40) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:4) +#loc54 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":144:21) +#loc55 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":145:23) +#loc56 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":146:19) +#loc57 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":146:7) +#loc58 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":147:29) +#loc59 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":148:29) +#loc60 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:31) +#loc61 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:27) +#loc62 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:16) +#loc63 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:27) +#loc64 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:17) +#loc66 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:31) +#loc67 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:21) +#loc68 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:12) +#loc69 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:35) +#loc70 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:69) +#loc71 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:11) +#loc72 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:4) +#loc74 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:29) +#loc75 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:11) +#loc76 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":87:4) +#loc78 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:30) +#loc79 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:15) +#loc80 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:11) +#loc81 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":65:4) +#loc82 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc83 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc84 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc85 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc87 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:42) +#loc88 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:11) +#loc89 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:4) +#loc96 = loc("r0_numel"(#loc1)) +#loc97 = loc("xoffset"(#loc2)) +#loc98 = loc("xoffset"(#loc3)) +#loc99 = loc("xoffset"(#loc4)) +#loc100 = loc("xindex"(#loc5)) +#loc101 = loc("xindex"(#loc6)) +#loc102 = loc("xindex"(#loc7)) +#loc103 = loc("xindex"(#loc8)) +#loc104 = loc("xmask"(#loc9)) +#loc105 = loc("r0_base"(#loc10)) +#loc106 = loc("r0_base"(#loc11)) +#loc107 = loc("r0_base"(#loc12)) +#loc108 = loc("_tmp2"(#loc13)) +#loc109 = loc("_tmp2_index"(#loc14)) +#loc110 = loc("_tmp2"(#loc15)) +#loc111 = loc("r0_index"(#loc16)) +#loc112 = loc("r0_mask"(#loc17)) +#loc113 = loc("tmp0"(#loc18)) +#loc114 = loc("tmp0"(#loc19)) +#loc115 = loc("tmp0"(#loc20)) +#loc116 = loc("tmp0"(#loc21)) +#loc117 = loc("tmp0"(#loc22)) +#loc118 = loc("tmp0"(#loc23)) +#loc119 = loc("_tmp2"(#loc25)) +#loc120 = loc("_tmp2"(#loc26)) +#loc121 = loc("_tmp2_index"(#loc27)) +#loc122 = loc("_tmp2_index"(#loc28)) +#loc123 = loc("tmp2"(#loc31)) +#loc124 = loc("tmp11"(#loc32)) +#loc125 = loc("tmp11"(#loc33)) +#loc126 = loc("tmp3"(#loc34)) +#loc127 = loc("tmp4"(#loc35)) +#loc128 = loc("tmp5"(#loc36)) +#loc129 = loc("tmp6"(#loc37)) +#loc130 = loc("tmp8"(#loc44)) +#loc131 = loc("tmp8"(#loc45)) +#loc132 = loc("tmp9"(#loc46)) +#loc133 = loc("tmp10"(#loc47)) +#loc134 = loc("tmp12"(#loc48)) +#loc139 = loc("mask"(#loc54)) +#loc140 = loc("equal"(#loc55)) +#loc141 = loc("a_isnan"(#loc58)) +#loc142 = loc("b_isnan"(#loc59)) +#loc143 = loc("mask"(#loc60)) +#loc144 = loc("mask"(#loc61)) +#loc145 = loc("mask"(#loc62)) +#loc146 = loc("equal"(#loc63)) +#loc147 = loc("equal"(#loc64)) +#loc148 = loc("mask"(#loc66)) +#loc149 = loc("mask"(#loc67)) +#loc150 = loc("mask"(#loc68)) +#loc155 = loc("_tmp2_index"(#loc110)) +#loc156 = loc("mask"(#loc139)) +#loc157 = loc("equal"(#loc140)) +#loc158 = loc("mask"(#loc145)) +#loc159 = loc("equal"(#loc147)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..8f5c57d677a6cf90521cd99aa7bd86df3556a4ee --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttgir @@ -0,0 +1,250 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 16], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 16], order = [1, 0]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":18:0) +#loc1 = loc(unknown) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":43:75) +#loc57 = loc("in_out_ptr0"(#loc)) +#loc58 = loc("in_ptr0"(#loc)) +#loc59 = loc("in_ptr1"(#loc)) +#loc60 = loc("in_ptr2"(#loc)) +#loc61 = loc("xnumel"(#loc)) +#loc62 = loc("r0_numel"(#loc)) +#loc94 = loc(callsite(#loc1 at #loc35)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(%in_out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_out_ptr0"(#loc)), %in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %xnumel: i64 loc("xnumel"(#loc)), %r0_numel: i64 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes 
{noinline = false} { + %cst = arith.constant dense<0> : tensor<1x1xi64, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<151936> : tensor<1x1xi64, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<151936> : tensor<1x1xi64, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<9223372036854775807> : tensor<1x2048xi64, #blocked1> loc(#loc1) + %cst_3 = arith.constant dense<0xFF800000> : tensor<1x2048xf32, #blocked1> loc(#loc1) + %cst_4 = arith.constant dense<151936> : tensor<1x2048xi64, #blocked1> loc(#loc1) + %cst_5 = arith.constant dense<0> : tensor<1x1xi8, #blocked> loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<1x2048xbf16, #blocked1> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c151936_i32 = arith.constant 151936 : i32 loc(#loc1) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc1) + %c151936_i64 = arith.constant 151936 : i64 loc(#loc1) + %true = arith.constant true loc(#loc1) + %cst_7 = arith.constant dense : tensor<1x2048xi1, #blocked1> loc(#loc1) + %cst_8 = arith.constant dense<0> : tensor<1x1xi64, #blocked1> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc63) + %xoffset_9 = arith.extsi %xoffset : i32 to i64 loc(#loc64) + %xmask = arith.cmpi slt, %xoffset_9, %xnumel : i64 loc(#loc65) + %r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc66) + %r0_base_10 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x2048xi32, #blocked1> loc(#loc66) + %r0_base_11 = arith.extsi %r0_base_10 : tensor<1x2048xi32, #blocked1> to tensor<1x2048xi64, #blocked1> loc(#loc67) + %tmp0 = arith.muli %xoffset_9, %c151936_i64 : i64 loc(#loc68) + %tmp0_12 = tt.splat %tmp0 : i64 -> tensor<1x2048xi64, #blocked1> loc(#loc106) + %tmp0_13 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr, #blocked1> loc(#loc70) + %tmp0_14 = tt.splat %xmask : i1 -> tensor<1x2048xi1, #blocked1> loc(#loc107) + 
%_tmp2_index:2 = scf.for %_tmp2_index_28 = %c0_i32 to %c151936_i32 step %c2048_i32 iter_args(%_tmp2 = %cst_3, %_tmp2_index_29 = %cst_2) -> (tensor<1x2048xf32, #blocked1>, tensor<1x2048xi64, #blocked1>) : i32 { + %r0_index = arith.extsi %_tmp2_index_28 : i32 to i64 loc(#loc73) + %r0_index_30 = tt.splat %r0_index : i64 -> tensor<1x2048xi64, #blocked1> loc(#loc73) + %r0_index_31 = arith.addi %r0_index_30, %r0_base_11 : tensor<1x2048xi64, #blocked1> loc(#loc73) + %r0_mask = arith.cmpi slt, %r0_index_31, %cst_4 : tensor<1x2048xi64, #blocked1> loc(#loc74) + %tmp0_32 = arith.addi %r0_index_31, %tmp0_12 : tensor<1x2048xi64, #blocked1> loc(#loc69) + %tmp0_33 = tt.addptr %tmp0_13, %tmp0_32 : tensor<1x2048x!tt.ptr, #blocked1>, tensor<1x2048xi64, #blocked1> loc(#loc70) + %tmp0_34 = arith.andi %r0_mask, %tmp0_14 : tensor<1x2048xi1, #blocked1> loc(#loc71) + %tmp0_35 = tt.load %tmp0_33, %tmp0_34, %cst_6 evictionPolicy = evict_first : tensor<1x2048x!tt.ptr, #blocked1> loc(#loc75) + %tmp0_36 = arith.extf %tmp0_35 : tensor<1x2048xbf16, #blocked1> to tensor<1x2048xf32, #blocked1> loc(#loc76) + %mask = arith.cmpf ogt, %_tmp2, %tmp0_36 : tensor<1x2048xf32, #blocked1> loc(#loc133) + %equal = arith.cmpf oeq, %_tmp2, %tmp0_36 : tensor<1x2048xf32, #blocked1> loc(#loc134) + %a_isnan = arith.cmpf une, %_tmp2, %_tmp2 : tensor<1x2048xf32, #blocked1> loc(#loc111) + %b_isnan = arith.cmpf une, %tmp0_36, %tmp0_36 : tensor<1x2048xf32, #blocked1> loc(#loc112) + %mask_37 = arith.xori %b_isnan, %cst_7 : tensor<1x2048xi1, #blocked1> loc(#loc113) + %mask_38 = arith.andi %a_isnan, %mask_37 : tensor<1x2048xi1, #blocked1> loc(#loc114) + %mask_39 = arith.ori %mask, %mask_38 : tensor<1x2048xi1, #blocked1> loc(#loc135) + %equal_40 = arith.andi %a_isnan, %b_isnan : tensor<1x2048xi1, #blocked1> loc(#loc116) + %equal_41 = arith.ori %equal, %equal_40 : tensor<1x2048xi1, #blocked1> loc(#loc136) + %mask_42 = arith.cmpi slt, %_tmp2_index_29, %r0_index_31 : tensor<1x2048xi64, #blocked1> loc(#loc118) + %mask_43 = 
arith.andi %equal_41, %mask_42 : tensor<1x2048xi1, #blocked1> loc(#loc119) + %mask_44 = arith.ori %mask_39, %mask_43 : tensor<1x2048xi1, #blocked1> loc(#loc120) + %8 = arith.select %mask_44, %_tmp2, %tmp0_36 : tensor<1x2048xi1, #blocked1>, tensor<1x2048xf32, #blocked1> loc(#loc89) + %9 = arith.select %mask_44, %_tmp2_index_29, %r0_index_31 : tensor<1x2048xi1, #blocked1>, tensor<1x2048xi64, #blocked1> loc(#loc90) + %_tmp2_45 = arith.select %tmp0_34, %8, %_tmp2 : tensor<1x2048xi1, #blocked1>, tensor<1x2048xf32, #blocked1> loc(#loc91) + %_tmp2_index_46 = arith.select %tmp0_34, %9, %_tmp2_index_29 : tensor<1x2048xi1, #blocked1>, tensor<1x2048xi64, #blocked1> loc(#loc92) + scf.yield %_tmp2_45, %_tmp2_index_46 : tensor<1x2048xf32, #blocked1>, tensor<1x2048xi64, #blocked1> loc(#loc33) + } loc(#loc108) + %0:2 = "tt.reduce"(%_tmp2_index#0, %_tmp2_index#1) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc35)), %arg7: i64 loc(callsite(#loc1 at #loc35)), %arg8: f32 loc(callsite(#loc1 at #loc35)), %arg9: i64 loc(callsite(#loc1 at #loc35))): + %mask = arith.cmpf ogt, %arg6, %arg8 : f32 loc(#loc137) + %equal = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc138) + %a_isnan = arith.cmpf une, %arg6, %arg6 : f32 loc(#loc121) + %b_isnan = arith.cmpf une, %arg8, %arg8 : f32 loc(#loc122) + %mask_28 = arith.xori %b_isnan, %true : i1 loc(#loc123) + %mask_29 = arith.andi %a_isnan, %mask_28 : i1 loc(#loc124) + %mask_30 = arith.ori %mask, %mask_29 : i1 loc(#loc139) + %equal_31 = arith.andi %a_isnan, %b_isnan : i1 loc(#loc125) + %equal_32 = arith.ori %equal, %equal_31 : i1 loc(#loc140) + %mask_33 = arith.cmpi slt, %arg7, %arg9 : i64 loc(#loc126) + %mask_34 = arith.andi %equal_32, %mask_33 : i1 loc(#loc127) + %mask_35 = arith.ori %mask_30, %mask_34 : i1 loc(#loc128) + %8 = arith.select %mask_35, %arg6, %arg8 : f32 loc(#loc129) + %9 = arith.select %mask_35, %arg7, %arg9 : i64 loc(#loc130) + tt.reduce.return %8, %9 : f32, i64 loc(#loc93) + }) : (tensor<1x2048xf32, #blocked1>, 
tensor<1x2048xi64, #blocked1>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>) loc(#loc93) + %tmp8 = ttg.convert_layout %0#1 : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc95) + %tmp2 = tt.expand_dims %tmp8 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1xi64, #blocked> loc(#loc96) + %tmp2_15 = tt.expand_dims %0#1 {axis = 1 : i32} : tensor<1xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1xi64, #blocked1> loc(#loc96) + %tmp11 = tt.addptr %in_ptr2, %xoffset_9 : !tt.ptr, i64 loc(#loc97) + %tmp11_16 = tt.splat %tmp11 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc98) + %tmp11_17 = tt.splat %xmask : i1 -> tensor<1x1xi1, #blocked> loc(#loc98) + %tmp11_18 = tt.load %tmp11_16, %tmp11_17 evictionPolicy = evict_last : tensor<1x1x!tt.ptr, #blocked> loc(#loc98) + %tmp4 = arith.addi %tmp2, %cst_0 : tensor<1x1xi64, #blocked> loc(#loc99) + %tmp4_19 = arith.addi %tmp2_15, %cst_1 : tensor<1x1xi64, #blocked1> loc(#loc99) + %tmp5 = arith.cmpi slt, %tmp2, %cst : tensor<1x1xi64, #blocked> loc(#loc100) + %tmp5_20 = arith.cmpi slt, %tmp2_15, %cst_8 : tensor<1x1xi64, #blocked1> loc(#loc100) + %tmp6 = arith.select %tmp5, %tmp4, %tmp2 : tensor<1x1xi1, #blocked>, tensor<1x1xi64, #blocked> loc(#loc101) + %tmp6_21 = arith.select %tmp5_20, %tmp4_19, %tmp2_15 : tensor<1x1xi1, #blocked1>, tensor<1x1xi64, #blocked1> loc(#loc101) + %1 = arith.cmpi sge, %tmp6_21, %cst_8 : tensor<1x1xi64, #blocked1> loc(#loc43) + %2 = arith.cmpi slt, %tmp6_21, %cst_1 : tensor<1x1xi64, #blocked1> loc(#loc44) + %3 = arith.andi %1, %2 : tensor<1x1xi1, #blocked1> loc(#loc45) + %xmask_22 = arith.cmpi sge, %xoffset_9, %xnumel : i64 loc(#loc131) + %4 = tt.splat %xmask_22 : i1 -> tensor<1x1xi1, #blocked1> loc(#loc46) + %5 = arith.ori %3, %4 : tensor<1x1xi1, #blocked1> loc(#loc47) + tt.assert %5, "index out of bounds: 0 
<= tmp6 < 151936" : tensor<1x1xi1, #blocked1> loc(#loc48) + %tmp8_23 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc102) + %tmp8_24 = tt.addptr %tmp8_23, %tmp6 : tensor<1x1x!tt.ptr, #blocked>, tensor<1x1xi64, #blocked> loc(#loc102) + %tmp8_25 = tt.bitcast %tmp8_24 : tensor<1x1x!tt.ptr, #blocked> -> tensor<1x1x!tt.ptr, #blocked> loc(#loc95) + %tmp8_26 = tt.load %tmp8_25, %tmp11_17 evictionPolicy = evict_last : tensor<1x1x!tt.ptr, #blocked> loc(#loc95) + %tmp8_27 = arith.cmpi ne, %tmp8_26, %cst_5 : tensor<1x1xi8, #blocked> loc(#loc95) + %tmp10 = arith.extui %tmp8_27 : tensor<1x1xi1, #blocked> to tensor<1x1xi64, #blocked> loc(#loc132) + %tmp12 = arith.muli %tmp10, %tmp11_18 : tensor<1x1xi64, #blocked> loc(#loc105) + gpu.barrier loc(#loc53) + %6 = tt.addptr %in_out_ptr0, %xoffset_9 : !tt.ptr, i64 loc(#loc54) + %7 = tt.splat %6 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked> loc(#loc55) + tt.store %7, %tmp12, %tmp11_17 : tensor<1x1x!tt.ptr, #blocked> loc(#loc55) + tt.return loc(#loc56) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:34) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":24:21) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:37) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:49) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:48) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:41) 
+#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:34) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:63) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":30:40) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":31:31) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":32:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:53) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:115) +#loc16 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":144:21) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":39:38) +#loc18 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":145:23) +#loc19 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":147:29) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":148:29) +#loc21 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:31) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:27) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:16) +#loc24 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:27) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:17) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:31) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:21) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:12) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:35) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:69) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":41:54) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:66) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:8) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:42) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:37) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":44:20) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:31) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:36) +#loc40 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":47:18) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":48:18) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":49:32) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:28) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:44) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:37) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:57) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:55) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:65) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:30) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":53:20) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":52:19) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":54:20) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":55:4) +#loc54 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:28) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:40) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:4) +#loc63 = loc("xoffset"(#loc2)) +#loc64 = loc("xoffset"(#loc3)) +#loc65 = loc("xmask"(#loc4)) +#loc66 = loc("r0_base"(#loc5)) +#loc67 = loc("r0_base"(#loc6)) +#loc68 = loc("tmp0"(#loc7)) +#loc69 = loc("tmp0"(#loc8)) +#loc70 = loc("tmp0"(#loc9)) +#loc71 = loc("tmp0"(#loc10)) +#loc72 = loc("_tmp2"(#loc11)) +#loc73 = loc("r0_index"(#loc12)) +#loc74 = loc("r0_mask"(#loc13)) +#loc75 = loc("tmp0"(#loc14)) +#loc76 = loc("tmp0"(#loc15)) +#loc77 = loc("mask"(#loc16)) +#loc78 = loc("equal"(#loc18)) +#loc79 = loc("a_isnan"(#loc19)) +#loc80 = loc("b_isnan"(#loc20)) +#loc81 = loc("mask"(#loc21)) +#loc82 = loc("mask"(#loc22)) +#loc83 = loc("mask"(#loc23)) +#loc84 = loc("equal"(#loc24)) +#loc85 = loc("equal"(#loc25)) +#loc86 = loc("mask"(#loc26)) +#loc87 = loc("mask"(#loc27)) +#loc88 = loc("mask"(#loc28)) +#loc89 = loc(callsite(#loc29 at #loc17)) +#loc90 = loc(callsite(#loc30 at #loc17)) +#loc91 = loc("_tmp2"(#loc31)) +#loc92 = loc("_tmp2_index"(#loc32)) +#loc93 = loc(callsite(#loc34 at #loc35)) +#loc95 = loc("tmp8"(#loc36)) +#loc96 = loc("tmp2"(#loc37)) +#loc97 = loc("tmp11"(#loc38)) +#loc98 = loc("tmp11"(#loc39)) +#loc99 = loc("tmp4"(#loc40)) +#loc100 = loc("tmp5"(#loc41)) +#loc101 = loc("tmp6"(#loc42)) +#loc102 = loc("tmp8"(#loc49)) +#loc103 = loc("tmp10"(#loc50)) +#loc104 = loc("tmp9"(#loc51)) +#loc105 = loc("tmp12"(#loc52)) +#loc106 = loc(fused[#loc69, #loc68]) +#loc107 = loc(fused[#loc71, #loc65]) +#loc108 = loc("_tmp2_index"(#loc72)) +#loc109 = loc("mask"(#loc77)) +#loc110 = loc("equal"(#loc78)) +#loc111 = loc(callsite(#loc79 at #loc17)) +#loc112 = loc(callsite(#loc80 at #loc17)) 
+#loc113 = loc(callsite(#loc81 at #loc17)) +#loc114 = loc(callsite(#loc82 at #loc17)) +#loc115 = loc("mask"(#loc83)) +#loc116 = loc(callsite(#loc84 at #loc17)) +#loc117 = loc("equal"(#loc85)) +#loc118 = loc(callsite(#loc86 at #loc17)) +#loc119 = loc(callsite(#loc87 at #loc17)) +#loc120 = loc(callsite(#loc88 at #loc17)) +#loc121 = loc(callsite(#loc79 at #loc93)) +#loc122 = loc(callsite(#loc80 at #loc93)) +#loc123 = loc(callsite(#loc81 at #loc93)) +#loc124 = loc(callsite(#loc82 at #loc93)) +#loc125 = loc(callsite(#loc84 at #loc93)) +#loc126 = loc(callsite(#loc86 at #loc93)) +#loc127 = loc(callsite(#loc87 at #loc93)) +#loc128 = loc(callsite(#loc88 at #loc93)) +#loc129 = loc(callsite(#loc29 at #loc93)) +#loc130 = loc(callsite(#loc30 at #loc93)) +#loc131 = loc(fused[#loc46, #loc65]) +#loc132 = loc(fused[#loc103, #loc104]) +#loc133 = loc(callsite(#loc109 at #loc17)) +#loc134 = loc(callsite(#loc110 at #loc17)) +#loc135 = loc(callsite(#loc115 at #loc17)) +#loc136 = loc(callsite(#loc117 at #loc17)) +#loc137 = loc(callsite(#loc109 at #loc93)) +#loc138 = loc(callsite(#loc110 at #loc93)) +#loc139 = loc(callsite(#loc115 at #loc93)) +#loc140 = loc(callsite(#loc117 at #loc93)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..fa2e9ad50edf18c1ab6a8b75a5b33e4a93c84908 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/OU2OAVK633YHT43FEX32EDY5CTULBM4PI6AKNCY4Q5R3A3O236SA/triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.ttir @@ -0,0 +1,246 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":18:0) +#loc1 = loc(unknown) +#loc39 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":43:75) +#loc60 = loc("in_out_ptr0"(#loc)) +#loc61 = loc("in_ptr0"(#loc)) +#loc62 = loc("in_ptr1"(#loc)) +#loc63 = loc("in_ptr2"(#loc)) +#loc64 = loc("xnumel"(#loc)) +#loc65 = loc("r0_numel"(#loc)) +#loc101 = loc(callsite(#loc1 at #loc39)) +module { + tt.func public @triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(%in_out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_out_ptr0"(#loc)), %in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %xnumel: i64 loc("xnumel"(#loc)), %r0_numel: i64 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %c151936_i64 = arith.constant 151936 : i64 loc(#loc1) + %true = arith.constant true loc(#loc1) + %cst = arith.constant dense : tensor<1x2048xi1> loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<1x1xi64> loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<1x2048xbf16> loc(#loc1) + %cst_2 = arith.constant dense<151936> : tensor<1x2048xi64> loc(#loc1) + %c2048_i32 = arith.constant 2048 : i32 loc(#loc2) + %c151936_i32 = arith.constant 151936 : i32 loc(#loc2) + %c0_i32 = arith.constant 0 : i32 loc(#loc2) + %tmp8 = arith.constant dense<0> : tensor<1x1xi8> loc(#loc66) + %cst_3 = arith.constant dense<151936> : tensor<1x1xi64> loc(#loc1) + %_tmp2_index = arith.constant dense<9223372036854775807> : tensor<1x2048xi64> loc(#loc67) + %_tmp2 = arith.constant dense<0xFF800000> : tensor<1x2048xf32> loc(#loc68) + %xoffset = tt.get_program_id x : i32 loc(#loc69) + %xoffset_4 = arith.extsi %xoffset : i32 to i64 loc(#loc70) + %xmask = arith.cmpi slt, %xoffset_4, %xnumel : i64 loc(#loc71) + %xmask_5 = tt.splat %xmask : i1 -> tensor<1x1xi1> loc(#loc71) + %r0_base = tt.make_range {end = 2048 : i32, start = 0 : i32} : 
tensor<2048xi32> loc(#loc72) + %r0_base_6 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<2048xi32> -> tensor<1x2048xi32> loc(#loc73) + %r0_base_7 = arith.extsi %r0_base_6 : tensor<1x2048xi32> to tensor<1x2048xi64> loc(#loc74) + %_tmp2_index_8:2 = scf.for %r0_offset = %c0_i32 to %c151936_i32 step %c2048_i32 iter_args(%_tmp2_16 = %_tmp2, %_tmp2_index_17 = %_tmp2_index) -> (tensor<1x2048xf32>, tensor<1x2048xi64>) : i32 { + %r0_index = arith.extsi %r0_offset : i32 to i64 loc(#loc76) + %r0_index_18 = tt.splat %r0_index : i64 -> tensor<1x2048xi64> loc(#loc76) + %r0_index_19 = arith.addi %r0_index_18, %r0_base_7 : tensor<1x2048xi64> loc(#loc76) + %r0_mask = arith.cmpi slt, %r0_index_19, %cst_2 : tensor<1x2048xi64> loc(#loc77) + %tmp0 = arith.muli %xoffset_4, %c151936_i64 : i64 loc(#loc78) + %tmp0_20 = tt.splat %tmp0 : i64 -> tensor<1x2048xi64> loc(#loc113) + %tmp0_21 = arith.addi %r0_index_19, %tmp0_20 : tensor<1x2048xi64> loc(#loc79) + %tmp0_22 = tt.splat %in_ptr0 : !tt.ptr -> tensor<1x2048x!tt.ptr> loc(#loc80) + %tmp0_23 = tt.addptr %tmp0_22, %tmp0_21 : tensor<1x2048x!tt.ptr>, tensor<1x2048xi64> loc(#loc80) + %tmp0_24 = tt.splat %xmask : i1 -> tensor<1x2048xi1> loc(#loc114) + %tmp0_25 = arith.andi %r0_mask, %tmp0_24 : tensor<1x2048xi1> loc(#loc81) + %tmp0_26 = tt.load %tmp0_23, %tmp0_25, %cst_1 evictionPolicy = evict_first : tensor<1x2048x!tt.ptr> loc(#loc82) + %tmp0_27 = arith.extf %tmp0_26 : tensor<1x2048xbf16> to tensor<1x2048xf32> loc(#loc83) + %mask = arith.cmpf ogt, %_tmp2_16, %tmp0_27 : tensor<1x2048xf32> loc(#loc138) + %equal = arith.cmpf oeq, %_tmp2_16, %tmp0_27 : tensor<1x2048xf32> loc(#loc139) + %a_isnan = arith.cmpf une, %_tmp2_16, %_tmp2_16 : tensor<1x2048xf32> loc(#loc117) + %b_isnan = arith.cmpf une, %tmp0_27, %tmp0_27 : tensor<1x2048xf32> loc(#loc118) + %mask_28 = arith.xori %b_isnan, %cst : tensor<1x2048xi1> loc(#loc119) + %mask_29 = arith.andi %a_isnan, %mask_28 : tensor<1x2048xi1> loc(#loc120) + %mask_30 = arith.ori %mask, %mask_29 : 
tensor<1x2048xi1> loc(#loc140) + %equal_31 = arith.andi %a_isnan, %b_isnan : tensor<1x2048xi1> loc(#loc122) + %equal_32 = arith.ori %equal, %equal_31 : tensor<1x2048xi1> loc(#loc141) + %mask_33 = arith.cmpi slt, %_tmp2_index_17, %r0_index_19 : tensor<1x2048xi64> loc(#loc124) + %mask_34 = arith.andi %equal_32, %mask_33 : tensor<1x2048xi1> loc(#loc125) + %mask_35 = arith.ori %mask_30, %mask_34 : tensor<1x2048xi1> loc(#loc126) + %9 = arith.select %mask_35, %_tmp2_16, %tmp0_27 : tensor<1x2048xi1>, tensor<1x2048xf32> loc(#loc96) + %10 = arith.select %mask_35, %_tmp2_index_17, %r0_index_19 : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc97) + %_tmp2_36 = arith.select %tmp0_25, %9, %_tmp2_16 : tensor<1x2048xi1>, tensor<1x2048xf32> loc(#loc98) + %_tmp2_index_37 = arith.select %tmp0_25, %10, %_tmp2_index_17 : tensor<1x2048xi1>, tensor<1x2048xi64> loc(#loc99) + scf.yield %_tmp2_36, %_tmp2_index_37 : tensor<1x2048xf32>, tensor<1x2048xi64> loc(#loc37) + } loc(#loc112) + %0:2 = "tt.reduce"(%_tmp2_index_8#0, %_tmp2_index_8#1) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc39)), %arg7: i64 loc(callsite(#loc1 at #loc39)), %arg8: f32 loc(callsite(#loc1 at #loc39)), %arg9: i64 loc(callsite(#loc1 at #loc39))): + %mask = arith.cmpf ogt, %arg6, %arg8 : f32 loc(#loc142) + %equal = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc143) + %a_isnan = arith.cmpf une, %arg6, %arg6 : f32 loc(#loc127) + %b_isnan = arith.cmpf une, %arg8, %arg8 : f32 loc(#loc128) + %mask_16 = arith.xori %b_isnan, %true : i1 loc(#loc129) + %mask_17 = arith.andi %a_isnan, %mask_16 : i1 loc(#loc130) + %mask_18 = arith.ori %mask, %mask_17 : i1 loc(#loc144) + %equal_19 = arith.andi %a_isnan, %b_isnan : i1 loc(#loc131) + %equal_20 = arith.ori %equal, %equal_19 : i1 loc(#loc145) + %mask_21 = arith.cmpi slt, %arg7, %arg9 : i64 loc(#loc132) + %mask_22 = arith.andi %equal_20, %mask_21 : i1 loc(#loc133) + %mask_23 = arith.ori %mask_18, %mask_22 : i1 loc(#loc134) + %9 = arith.select %mask_23, %arg6, %arg8 : f32 
loc(#loc135) + %10 = arith.select %mask_23, %arg7, %arg9 : i64 loc(#loc136) + tt.reduce.return %9, %10 : f32, i64 loc(#loc100) + }) : (tensor<1x2048xf32>, tensor<1x2048xi64>) -> (tensor<1xf32>, tensor<1xi64>) loc(#loc100) + %tmp2 = tt.expand_dims %0#1 {axis = 1 : i32} : tensor<1xi64> -> tensor<1x1xi64> loc(#loc102) + %tmp11 = tt.addptr %in_ptr2, %xoffset_4 : !tt.ptr, i64 loc(#loc103) + %tmp11_9 = tt.splat %tmp11 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc103) + %tmp11_10 = tt.load %tmp11_9, %xmask_5 evictionPolicy = evict_last : tensor<1x1x!tt.ptr> loc(#loc104) + %tmp4 = arith.addi %tmp2, %cst_3 : tensor<1x1xi64> loc(#loc105) + %tmp5 = arith.cmpi slt, %tmp2, %cst_0 : tensor<1x1xi64> loc(#loc106) + %tmp6 = arith.select %tmp5, %tmp4, %tmp2 : tensor<1x1xi1>, tensor<1x1xi64> loc(#loc107) + %1 = arith.cmpi sge, %tmp6, %cst_0 : tensor<1x1xi64> loc(#loc46) + %2 = arith.cmpi slt, %tmp6, %cst_3 : tensor<1x1xi64> loc(#loc47) + %3 = arith.andi %1, %2 : tensor<1x1xi1> loc(#loc48) + %4 = arith.xori %xmask, %true : i1 loc(#loc49) + %5 = tt.splat %4 : i1 -> tensor<1x1xi1> loc(#loc49) + %6 = arith.ori %3, %5 : tensor<1x1xi1> loc(#loc50) + tt.assert %6, "index out of bounds: 0 <= tmp6 < 151936" : tensor<1x1xi1> loc(#loc51) + %tmp8_11 = tt.splat %in_ptr1 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc108) + %tmp8_12 = tt.addptr %tmp8_11, %tmp6 : tensor<1x1x!tt.ptr>, tensor<1x1xi64> loc(#loc108) + %tmp8_13 = tt.bitcast %tmp8_12 : tensor<1x1x!tt.ptr> -> tensor<1x1x!tt.ptr> loc(#loc66) + %tmp8_14 = tt.load %tmp8_13, %xmask_5 evictionPolicy = evict_last : tensor<1x1x!tt.ptr> loc(#loc66) + %tmp8_15 = arith.cmpi ne, %tmp8_14, %tmp8 : tensor<1x1xi8> loc(#loc66) + %tmp10 = arith.extui %tmp8_15 : tensor<1x1xi1> to tensor<1x1xi64> loc(#loc137) + %tmp12 = arith.muli %tmp10, %tmp11_10 : tensor<1x1xi64> loc(#loc111) + gpu.barrier loc(#loc56) + %7 = tt.addptr %in_out_ptr0, %xoffset_4 : !tt.ptr, i64 loc(#loc57) + %8 = tt.splat %7 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc57) + tt.store %8, %tmp12, %xmask_5 
: tensor<1x1x!tt.ptr> loc(#loc58) + tt.return loc(#loc59) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":30:40) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:37) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":29:67) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":28:55) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:28) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":22:34) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":24:21) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:27) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:37) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":25:49) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":31:31) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":32:29) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:48) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:41) 
+#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:34) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:63) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:53) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":36:115) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":144:21) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":39:38) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":145:23) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":147:29) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":148:29) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:31) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:27) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:16) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:27) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:17) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:31) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:21) +#loc32 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:12) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:35) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:69) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":41:54) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:66) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":42:8) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:42) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":44:20) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:31) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":45:36) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":47:18) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":48:18) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":49:32) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:28) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:44) +#loc48 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:37) +#loc49 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:57) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:55) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":50:65) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":51:30) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":53:20) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":52:19) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":54:20) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":55:4) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:28) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:40) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py":56:4) +#loc66 = loc("tmp8"(#loc3)) +#loc67 = loc("_tmp2_index"(#loc4)) +#loc68 = loc("_tmp2"(#loc5)) +#loc69 = loc("xoffset"(#loc6)) +#loc70 = loc("xoffset"(#loc7)) +#loc71 = loc("xmask"(#loc8)) +#loc72 = loc("r0_base"(#loc9)) +#loc73 = loc("r0_base"(#loc10)) +#loc74 = loc("r0_base"(#loc11)) +#loc75 = loc("_tmp2"(#loc2)) +#loc76 = loc("r0_index"(#loc12)) +#loc77 = 
loc("r0_mask"(#loc13)) +#loc78 = loc("tmp0"(#loc14)) +#loc79 = loc("tmp0"(#loc15)) +#loc80 = loc("tmp0"(#loc16)) +#loc81 = loc("tmp0"(#loc17)) +#loc82 = loc("tmp0"(#loc18)) +#loc83 = loc("tmp0"(#loc19)) +#loc84 = loc("mask"(#loc20)) +#loc85 = loc("equal"(#loc22)) +#loc86 = loc("a_isnan"(#loc23)) +#loc87 = loc("b_isnan"(#loc24)) +#loc88 = loc("mask"(#loc25)) +#loc89 = loc("mask"(#loc26)) +#loc90 = loc("mask"(#loc27)) +#loc91 = loc("equal"(#loc28)) +#loc92 = loc("equal"(#loc29)) +#loc93 = loc("mask"(#loc30)) +#loc94 = loc("mask"(#loc31)) +#loc95 = loc("mask"(#loc32)) +#loc96 = loc(callsite(#loc33 at #loc21)) +#loc97 = loc(callsite(#loc34 at #loc21)) +#loc98 = loc("_tmp2"(#loc35)) +#loc99 = loc("_tmp2_index"(#loc36)) +#loc100 = loc(callsite(#loc38 at #loc39)) +#loc102 = loc("tmp2"(#loc40)) +#loc103 = loc("tmp11"(#loc41)) +#loc104 = loc("tmp11"(#loc42)) +#loc105 = loc("tmp4"(#loc43)) +#loc106 = loc("tmp5"(#loc44)) +#loc107 = loc("tmp6"(#loc45)) +#loc108 = loc("tmp8"(#loc52)) +#loc109 = loc("tmp10"(#loc53)) +#loc110 = loc("tmp9"(#loc54)) +#loc111 = loc("tmp12"(#loc55)) +#loc112 = loc("_tmp2_index"(#loc75)) +#loc113 = loc(fused[#loc79, #loc78]) +#loc114 = loc(fused[#loc81, #loc71]) +#loc115 = loc("mask"(#loc84)) +#loc116 = loc("equal"(#loc85)) +#loc117 = loc(callsite(#loc86 at #loc21)) +#loc118 = loc(callsite(#loc87 at #loc21)) +#loc119 = loc(callsite(#loc88 at #loc21)) +#loc120 = loc(callsite(#loc89 at #loc21)) +#loc121 = loc("mask"(#loc90)) +#loc122 = loc(callsite(#loc91 at #loc21)) +#loc123 = loc("equal"(#loc92)) +#loc124 = loc(callsite(#loc93 at #loc21)) +#loc125 = loc(callsite(#loc94 at #loc21)) +#loc126 = loc(callsite(#loc95 at #loc21)) +#loc127 = loc(callsite(#loc86 at #loc100)) +#loc128 = loc(callsite(#loc87 at #loc100)) +#loc129 = loc(callsite(#loc88 at #loc100)) +#loc130 = loc(callsite(#loc89 at #loc100)) +#loc131 = loc(callsite(#loc91 at #loc100)) +#loc132 = loc(callsite(#loc93 at #loc100)) +#loc133 = loc(callsite(#loc94 at #loc100)) +#loc134 = 
loc(callsite(#loc95 at #loc100)) +#loc135 = loc(callsite(#loc33 at #loc100)) +#loc136 = loc(callsite(#loc34 at #loc100)) +#loc137 = loc(fused[#loc109, #loc110]) +#loc138 = loc(callsite(#loc115 at #loc21)) +#loc139 = loc(callsite(#loc116 at #loc21)) +#loc140 = loc(callsite(#loc121 at #loc21)) +#loc141 = loc(callsite(#loc123 at #loc21)) +#loc142 = loc(callsite(#loc115 at #loc100)) +#loc143 = loc(callsite(#loc116 at #loc100)) +#loc144 = loc(callsite(#loc121 at #loc100)) +#loc145 = loc(callsite(#loc123 at #loc100)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/__grp__triton_red_fused_argmax_1.json b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/__grp__triton_red_fused_argmax_1.json new file mode 100644 index 0000000000000000000000000000000000000000..88899e4f5b5572e21ffd68ef5fccb1d04865b608 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/__grp__triton_red_fused_argmax_1.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_argmax_1.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.source", "triton_red_fused_argmax_1.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttir", "triton_red_fused_argmax_1.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttgir", "triton_red_fused_argmax_1.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.llir", "triton_red_fused_argmax_1.ptx": 
"/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ptx", "triton_red_fused_argmax_1.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.cubin", "triton_red_fused_argmax_1.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.json"}} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.cubin b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.cubin new file mode 100644 index 0000000000000000000000000000000000000000..367b18d354f818aee1724fc23a51dbc42af0c6ee Binary files /dev/null and b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.cubin differ diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.json b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.json new file mode 100644 index 0000000000000000000000000000000000000000..d96f7459ed11906ae82a550833d81a065d884dba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.json @@ -0,0 +1 @@ +{"hash": "7ee1902e8230e08302255299de45d2a81108c8b2d8456e6864d90149c9821909", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 16, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, 
"launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 1024, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_argmax_1"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.llir b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.llir new file mode 100644 index 0000000000000000000000000000000000000000..71b63b9fc009c88e9430207b9ddcecc9b0536971 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.llir @@ -0,0 +1,1166 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_argmax_1(ptr addrspace(1) %0, ptr addrspace(1) %1, i64 %2, i64 %3, i32 %4, i32 %5, ptr addrspace(1) readnone captures(none) %6, ptr addrspace(1) readnone captures(none) %7) local_unnamed_addr #0 !dbg !4 { + %9 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %10 = shl i32 %9, 6, !dbg !8 + %11 = tail call i32 
@llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %12 = and i32 %11, 448, !dbg !9 + %13 = and i32 %11, 63, !dbg !9 + %14 = lshr exact i32 %12, 6, !dbg !9 + %15 = or disjoint i32 %14, 8, !dbg !9 + %16 = or disjoint i32 %14, 16, !dbg !9 + %17 = or disjoint i32 %14, 24, !dbg !9 + %18 = insertelement <4 x i32> poison, i32 %14, i64 0, !dbg !9 + %19 = shufflevector <4 x i32> %18, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !9 + %20 = or disjoint <4 x i32> %19, , !dbg !9 + %21 = insertelement <8 x i32> poison, i32 %17, i64 4, !dbg !10 + %22 = insertelement <8 x i32> %21, i32 %16, i64 5, !dbg !10 + %23 = insertelement <8 x i32> %22, i32 %15, i64 6, !dbg !10 + %24 = insertelement <8 x i32> %23, i32 %14, i64 7, !dbg !10 + %25 = shufflevector <4 x i32> %20, <4 x i32> poison, <8 x i32> , !dbg !10 + %26 = shufflevector <8 x i32> %25, <8 x i32> %24, <8 x i32> , !dbg !10 + %27 = insertelement <8 x i32> poison, i32 %10, i64 0, !dbg !10 + %28 = shufflevector <8 x i32> %27, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !10 + %29 = or disjoint <8 x i32> %26, %28, !dbg !10 + %30 = insertelement <8 x i32> poison, i32 %4, i64 0, !dbg !11 + %31 = shufflevector <8 x i32> %30, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !11 + %32 = icmp slt <8 x i32> %29, %31, !dbg !11 + %33 = extractelement <8 x i32> %29, i64 7, !dbg !12 + %34 = sext i32 %33 to i64, !dbg !12 + %35 = extractelement <8 x i32> %29, i64 6, !dbg !12 + %36 = sext i32 %35 to i64, !dbg !12 + %37 = extractelement <8 x i32> %29, i64 5, !dbg !12 + %38 = sext i32 %37 to i64, !dbg !12 + %39 = extractelement <8 x i32> %29, i64 4, !dbg !12 + %40 = sext i32 %39 to i64, !dbg !12 + %41 = extractelement <8 x i32> %29, i64 3, !dbg !12 + %42 = sext i32 %41 to i64, !dbg !12 + %43 = extractelement <8 x i32> %29, i64 2, !dbg !12 + %44 = sext i32 %43 to i64, !dbg !12 + %45 = extractelement <8 x i32> %29, i64 1, !dbg !12 + %46 = sext i32 %45 to i64, !dbg !12 + %47 = extractelement <8 x i32> %29, i64 0, !dbg !12 + %48 = sext i32 %47 to 
i64, !dbg !12 + %.frozen = freeze i64 %34, !dbg !13 + %.frozen70 = freeze i64 %2, !dbg !13 + %49 = sdiv i64 %.frozen, %.frozen70, !dbg !13 + %50 = mul i64 %49, %.frozen70, !dbg !12 + %.decomposed = sub i64 %.frozen, %50, !dbg !12 + %.frozen71 = freeze i64 %36, !dbg !13 + %.frozen72 = freeze i64 %2, !dbg !13 + %51 = sdiv i64 %.frozen71, %.frozen72, !dbg !13 + %52 = mul i64 %51, %.frozen72, !dbg !12 + %.decomposed73 = sub i64 %.frozen71, %52, !dbg !12 + %.frozen74 = freeze i64 %38, !dbg !13 + %.frozen75 = freeze i64 %2, !dbg !13 + %53 = sdiv i64 %.frozen74, %.frozen75, !dbg !13 + %54 = mul i64 %53, %.frozen75, !dbg !12 + %.decomposed76 = sub i64 %.frozen74, %54, !dbg !12 + %.frozen77 = freeze i64 %40, !dbg !13 + %.frozen78 = freeze i64 %2, !dbg !13 + %55 = sdiv i64 %.frozen77, %.frozen78, !dbg !13 + %56 = mul i64 %55, %.frozen78, !dbg !12 + %.decomposed79 = sub i64 %.frozen77, %56, !dbg !12 + %.frozen80 = freeze i64 %42, !dbg !13 + %.frozen81 = freeze i64 %2, !dbg !13 + %57 = sdiv i64 %.frozen80, %.frozen81, !dbg !13 + %58 = mul i64 %57, %.frozen81, !dbg !12 + %.decomposed82 = sub i64 %.frozen80, %58, !dbg !12 + %.frozen83 = freeze i64 %44, !dbg !13 + %.frozen84 = freeze i64 %2, !dbg !13 + %59 = sdiv i64 %.frozen83, %.frozen84, !dbg !13 + %60 = mul i64 %59, %.frozen84, !dbg !12 + %.decomposed85 = sub i64 %.frozen83, %60, !dbg !12 + %.frozen86 = freeze i64 %46, !dbg !13 + %.frozen87 = freeze i64 %2, !dbg !13 + %61 = sdiv i64 %.frozen86, %.frozen87, !dbg !13 + %62 = mul i64 %61, %.frozen87, !dbg !12 + %.decomposed88 = sub i64 %.frozen86, %62, !dbg !12 + %.frozen89 = freeze i64 %48, !dbg !13 + %.frozen90 = freeze i64 %2, !dbg !13 + %63 = sdiv i64 %.frozen89, %.frozen90, !dbg !13 + %64 = mul i64 %63, %.frozen90, !dbg !12 + %.decomposed91 = sub i64 %.frozen89, %64, !dbg !12 + %65 = mul i64 %49, %3, !dbg !14 + %66 = mul i64 %51, %3, !dbg !14 + %67 = mul i64 %53, %3, !dbg !14 + %68 = mul i64 %55, %3, !dbg !14 + %69 = mul i64 %57, %3, !dbg !14 + %70 = mul i64 %59, %3, !dbg 
!14 + %71 = mul i64 %61, %3, !dbg !14 + %72 = mul i64 %63, %3, !dbg !14 + %.idx = mul nsw i64 %.decomposed, 128000 + %73 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx + %invariant.gep = getelementptr float, ptr addrspace(1) %73, i64 %65, !dbg !15 + %.idx1 = mul nsw i64 %.decomposed73, 128000 + %74 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx1 + %invariant.gep9 = getelementptr float, ptr addrspace(1) %74, i64 %66, !dbg !15 + %.idx2 = mul nsw i64 %.decomposed76, 128000 + %75 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx2 + %invariant.gep11 = getelementptr float, ptr addrspace(1) %75, i64 %67, !dbg !15 + %.idx3 = mul nsw i64 %.decomposed79, 128000 + %76 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx3 + %invariant.gep13 = getelementptr float, ptr addrspace(1) %76, i64 %68, !dbg !15 + %.idx4 = mul nsw i64 %.decomposed82, 128000 + %77 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx4 + %invariant.gep15 = getelementptr float, ptr addrspace(1) %77, i64 %69, !dbg !15 + %.idx5 = mul nsw i64 %.decomposed85, 128000 + %78 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx5 + %invariant.gep17 = getelementptr float, ptr addrspace(1) %78, i64 %70, !dbg !15 + %.idx6 = mul nsw i64 %.decomposed88, 128000 + %79 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx6 + %invariant.gep19 = getelementptr float, ptr addrspace(1) %79, i64 %71, !dbg !15 + %.idx7 = mul nsw i64 %.decomposed91, 128000 + %80 = getelementptr i8, ptr addrspace(1) %0, i64 %.idx7 + %invariant.gep21 = getelementptr float, ptr addrspace(1) %80, i64 %72, !dbg !15 + %81 = zext nneg i32 %13 to i64, !dbg !15 + %82 = extractelement <8 x i1> %32, i64 0, !dbg !16 + %83 = extractelement <8 x i1> %32, i64 1, !dbg !16 + %84 = extractelement <8 x i1> %32, i64 2, !dbg !16 + %85 = extractelement <8 x i1> %32, i64 3, !dbg !16 + %86 = extractelement <8 x i1> %32, i64 4, !dbg !16 + %87 = extractelement <8 x i1> %32, i64 5, !dbg !16 + %88 = extractelement <8 x i1> %32, i64 6, !dbg !16 + %89 = extractelement <8 x i1> 
%32, i64 7, !dbg !16 + br label %90, !dbg !15 + +90: ; preds = %8, %90 + %indvars.iv = phi i64 [ 0, %8 ], [ %indvars.iv.next, %90 ] + %91 = phi <8 x float> [ splat (float 0xFFF0000000000000), %8 ], [ %139, %90 ] + %92 = phi <8 x i32> [ splat (i32 2147483647), %8 ], [ %140, %90 ] + %93 = or disjoint i64 %indvars.iv, %81, !dbg !17 + %gep = getelementptr float, ptr addrspace(1) %invariant.gep, i64 %93, !dbg !18 + %gep10 = getelementptr float, ptr addrspace(1) %invariant.gep9, i64 %93, !dbg !18 + %gep12 = getelementptr float, ptr addrspace(1) %invariant.gep11, i64 %93, !dbg !18 + %gep14 = getelementptr float, ptr addrspace(1) %invariant.gep13, i64 %93, !dbg !18 + %gep16 = getelementptr float, ptr addrspace(1) %invariant.gep15, i64 %93, !dbg !18 + %gep18 = getelementptr float, ptr addrspace(1) %invariant.gep17, i64 %93, !dbg !18 + %gep20 = getelementptr float, ptr addrspace(1) %invariant.gep19, i64 %93, !dbg !18 + %gep22 = getelementptr float, ptr addrspace(1) %invariant.gep21, i64 %93, !dbg !18 + %94 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %95 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep, i64 %94, i1 %89) #4, !dbg !16 + %96 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %97 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep10, i64 %96, i1 %88) #4, !dbg !16 + %98 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %99 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) 
%gep12, i64 %98, i1 %87) #4, !dbg !16 + %100 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %101 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep14, i64 %100, i1 %86) #4, !dbg !16 + %102 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %103 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep16, i64 %102, i1 %85) #4, !dbg !16 + %104 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %105 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep18, i64 %104, i1 %84) #4, !dbg !16 + %106 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %107 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep20, i64 %106, i1 %83) #4, !dbg !16 + %108 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !16 + %109 = tail call i32 asm sideeffect "mov.u32 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b32 { $0 }, [ $2 + 0 ], $3;", "=r,r,l,l,b"(i32 0, ptr addrspace(1) %gep22, i64 %108, i1 %82) #4, !dbg !16 + %110 = fcmp uno <8 x float> %91, zeroinitializer, !dbg !19 + %111 = trunc nuw nsw i64 %93 to i32, !dbg !23 + %112 = insertelement <8 x i32> poison, i32 %109, i64 0, !dbg !16 + %113 = insertelement <8 x i32> %112, i32 %107, 
i64 1, !dbg !16 + %114 = insertelement <8 x i32> %113, i32 %105, i64 2, !dbg !16 + %115 = insertelement <8 x i32> %114, i32 %103, i64 3, !dbg !16 + %116 = insertelement <8 x i32> %115, i32 %101, i64 4, !dbg !16 + %117 = insertelement <8 x i32> %116, i32 %99, i64 5, !dbg !16 + %118 = insertelement <8 x i32> %117, i32 %97, i64 6, !dbg !16 + %119 = insertelement <8 x i32> %118, i32 %95, i64 7, !dbg !16 + %120 = bitcast <8 x i32> %119 to <8 x float>, !dbg !16 + %121 = fcmp ogt <8 x float> %91, %120, !dbg !24 + %122 = fcmp oeq <8 x float> %91, %120, !dbg !25 + %123 = fcmp uno <8 x float> %120, zeroinitializer, !dbg !26 + %124 = xor <8 x i1> %123, splat (i1 true), !dbg !27 + %125 = and <8 x i1> %110, %124, !dbg !28 + %126 = or <8 x i1> %121, %125, !dbg !29 + %127 = and <8 x i1> %110, %123, !dbg !30 + %128 = or <8 x i1> %122, %127, !dbg !31 + %129 = insertelement <8 x i64> poison, i64 %93, i64 0, !dbg !32 + %130 = shufflevector <8 x i64> %129, <8 x i64> poison, <8 x i32> zeroinitializer, !dbg !32 + %131 = sext <8 x i32> %92 to <8 x i64>, !dbg !32 + %132 = icmp sgt <8 x i64> %130, %131, !dbg !32 + %133 = and <8 x i1> %132, %128, !dbg !33 + %134 = or <8 x i1> %126, %133, !dbg !34 + %135 = select <8 x i1> %134, <8 x float> %91, <8 x float> %120, !dbg !35 + %136 = insertelement <8 x i32> poison, i32 %111, i64 0, !dbg !23 + %137 = shufflevector <8 x i32> %136, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !23 + %138 = select <8 x i1> %134, <8 x i32> %92, <8 x i32> %137, !dbg !23 + %139 = select <8 x i1> %32, <8 x float> %135, <8 x float> %91, !dbg !36 + %140 = select <8 x i1> %32, <8 x i32> %138, <8 x i32> %92, !dbg !37 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 64, !dbg !15 + %141 = icmp samesign ult i64 %indvars.iv, 31936, !dbg !15 + br i1 %141, label %90, label %142, !dbg !15 + +142: ; preds = %90 + %143 = or disjoint i32 %10, %13, !dbg !10 + %144 = icmp slt i32 %143, %4, !dbg !11 + %145 = and i32 %11, 31, !dbg !9 + %146 = lshr i32 %11, 5, !dbg !9 + %147 = 
extractelement <8 x float> %139, i64 7, !dbg !38 + %148 = bitcast float %147 to i32, !dbg !38 + %149 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %148, i32 16, i32 31), !dbg !38 + %150 = bitcast i32 %149 to float, !dbg !38 + %151 = extractelement <8 x i32> %140, i64 7, !dbg !38 + %152 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %151, i32 16, i32 31), !dbg !38 + %153 = fcmp ogt float %147, %150, !dbg !40 + %154 = fcmp oeq float %147, %150, !dbg !41 + %155 = fcmp uno <8 x float> %139, zeroinitializer, !dbg !42 + %156 = fcmp uno float %150, 0.000000e+00, !dbg !43 + %157 = xor i1 %156, true, !dbg !44 + %158 = extractelement <8 x i1> %155, i64 7, !dbg !45 + %159 = and i1 %158, %157, !dbg !46 + %160 = or i1 %153, %159, !dbg !47 + %161 = and i1 %158, %156, !dbg !45 + %162 = or i1 %154, %161, !dbg !48 + %163 = icmp slt i32 %151, %152, !dbg !49 + %164 = and i1 %163, %162, !dbg !50 + %165 = or i1 %160, %164, !dbg !51 + %166 = select i1 %165, float %147, float %150, !dbg !52 + %167 = select i1 %165, i32 %151, i32 %152, !dbg !53 + %168 = bitcast float %166 to i32, !dbg !38 + %169 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %168, i32 8, i32 31), !dbg !38 + %170 = bitcast i32 %169 to float, !dbg !38 + %171 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %167, i32 8, i32 31), !dbg !38 + %172 = fcmp ogt float %166, %170, !dbg !40 + %173 = fcmp oeq float %166, %170, !dbg !41 + %174 = fcmp uno float %166, 0.000000e+00, !dbg !42 + %175 = fcmp uno float %170, 0.000000e+00, !dbg !43 + %176 = xor i1 %175, true, !dbg !44 + %177 = and i1 %174, %176, !dbg !46 + %178 = or i1 %172, %177, !dbg !47 + %179 = and i1 %175, %174, !dbg !45 + %180 = or i1 %173, %179, !dbg !48 + %181 = icmp slt i32 %167, %171, !dbg !49 + %182 = and i1 %181, %180, !dbg !50 + %183 = or i1 %178, %182, !dbg !51 + %184 = select i1 %183, float %166, float %170, !dbg !52 + %185 = select i1 %183, i32 %167, i32 %171, !dbg !53 + %186 = bitcast float %184 to i32, !dbg !38 + 
%187 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %186, i32 4, i32 31), !dbg !38 + %188 = bitcast i32 %187 to float, !dbg !38 + %189 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %185, i32 4, i32 31), !dbg !38 + %190 = fcmp ogt float %184, %188, !dbg !40 + %191 = fcmp oeq float %184, %188, !dbg !41 + %192 = fcmp uno float %184, 0.000000e+00, !dbg !42 + %193 = fcmp uno float %188, 0.000000e+00, !dbg !43 + %194 = xor i1 %193, true, !dbg !44 + %195 = and i1 %192, %194, !dbg !46 + %196 = or i1 %190, %195, !dbg !47 + %197 = and i1 %193, %192, !dbg !45 + %198 = or i1 %191, %197, !dbg !48 + %199 = icmp slt i32 %185, %189, !dbg !49 + %200 = and i1 %199, %198, !dbg !50 + %201 = or i1 %196, %200, !dbg !51 + %202 = select i1 %201, float %184, float %188, !dbg !52 + %203 = select i1 %201, i32 %185, i32 %189, !dbg !53 + %204 = bitcast float %202 to i32, !dbg !38 + %205 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %204, i32 2, i32 31), !dbg !38 + %206 = bitcast i32 %205 to float, !dbg !38 + %207 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %203, i32 2, i32 31), !dbg !38 + %208 = fcmp ogt float %202, %206, !dbg !40 + %209 = fcmp oeq float %202, %206, !dbg !41 + %210 = fcmp uno float %202, 0.000000e+00, !dbg !42 + %211 = fcmp uno float %206, 0.000000e+00, !dbg !43 + %212 = xor i1 %211, true, !dbg !44 + %213 = and i1 %210, %212, !dbg !46 + %214 = or i1 %208, %213, !dbg !47 + %215 = and i1 %211, %210, !dbg !45 + %216 = or i1 %209, %215, !dbg !48 + %217 = icmp slt i32 %203, %207, !dbg !49 + %218 = and i1 %217, %216, !dbg !50 + %219 = or i1 %214, %218, !dbg !51 + %220 = select i1 %219, float %202, float %206, !dbg !52 + %221 = select i1 %219, i32 %203, i32 %207, !dbg !53 + %222 = bitcast float %220 to i32, !dbg !38 + %223 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %222, i32 1, i32 31), !dbg !38 + %224 = bitcast i32 %223 to float, !dbg !38 + %225 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %221, i32 1, i32 
31), !dbg !38 + %226 = fcmp ogt float %220, %224, !dbg !40 + %227 = fcmp oeq float %220, %224, !dbg !41 + %228 = fcmp uno float %220, 0.000000e+00, !dbg !42 + %229 = fcmp uno float %224, 0.000000e+00, !dbg !43 + %230 = xor i1 %229, true, !dbg !44 + %231 = and i1 %228, %230, !dbg !46 + %232 = or i1 %226, %231, !dbg !47 + %233 = and i1 %229, %228, !dbg !45 + %234 = or i1 %227, %233, !dbg !48 + %235 = icmp slt i32 %221, %225, !dbg !49 + %236 = and i1 %235, %234, !dbg !50 + %237 = or i1 %232, %236, !dbg !51 + %238 = select i1 %237, i32 %221, i32 %225, !dbg !53 + %239 = extractelement <8 x float> %139, i64 6, !dbg !38 + %240 = bitcast float %239 to i32, !dbg !38 + %241 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %240, i32 16, i32 31), !dbg !38 + %242 = bitcast i32 %241 to float, !dbg !38 + %243 = extractelement <8 x i32> %140, i64 6, !dbg !38 + %244 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %243, i32 16, i32 31), !dbg !38 + %245 = fcmp ogt float %239, %242, !dbg !40 + %246 = fcmp oeq float %239, %242, !dbg !41 + %247 = fcmp uno float %242, 0.000000e+00, !dbg !43 + %248 = xor i1 %247, true, !dbg !44 + %249 = extractelement <8 x i1> %155, i64 6, !dbg !45 + %250 = and i1 %249, %248, !dbg !46 + %251 = or i1 %245, %250, !dbg !47 + %252 = and i1 %249, %247, !dbg !45 + %253 = or i1 %246, %252, !dbg !48 + %254 = icmp slt i32 %243, %244, !dbg !49 + %255 = and i1 %254, %253, !dbg !50 + %256 = or i1 %251, %255, !dbg !51 + %257 = select i1 %256, float %239, float %242, !dbg !52 + %258 = select i1 %256, i32 %243, i32 %244, !dbg !53 + %259 = bitcast float %257 to i32, !dbg !38 + %260 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %259, i32 8, i32 31), !dbg !38 + %261 = bitcast i32 %260 to float, !dbg !38 + %262 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %258, i32 8, i32 31), !dbg !38 + %263 = fcmp ogt float %257, %261, !dbg !40 + %264 = fcmp oeq float %257, %261, !dbg !41 + %265 = fcmp uno float %257, 0.000000e+00, !dbg !42 + 
%266 = fcmp uno float %261, 0.000000e+00, !dbg !43 + %267 = xor i1 %266, true, !dbg !44 + %268 = and i1 %265, %267, !dbg !46 + %269 = or i1 %263, %268, !dbg !47 + %270 = and i1 %266, %265, !dbg !45 + %271 = or i1 %264, %270, !dbg !48 + %272 = icmp slt i32 %258, %262, !dbg !49 + %273 = and i1 %272, %271, !dbg !50 + %274 = or i1 %269, %273, !dbg !51 + %275 = select i1 %274, float %257, float %261, !dbg !52 + %276 = select i1 %274, i32 %258, i32 %262, !dbg !53 + %277 = bitcast float %275 to i32, !dbg !38 + %278 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %277, i32 4, i32 31), !dbg !38 + %279 = bitcast i32 %278 to float, !dbg !38 + %280 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %276, i32 4, i32 31), !dbg !38 + %281 = fcmp ogt float %275, %279, !dbg !40 + %282 = fcmp oeq float %275, %279, !dbg !41 + %283 = fcmp uno float %275, 0.000000e+00, !dbg !42 + %284 = fcmp uno float %279, 0.000000e+00, !dbg !43 + %285 = xor i1 %284, true, !dbg !44 + %286 = and i1 %283, %285, !dbg !46 + %287 = or i1 %281, %286, !dbg !47 + %288 = and i1 %284, %283, !dbg !45 + %289 = or i1 %282, %288, !dbg !48 + %290 = icmp slt i32 %276, %280, !dbg !49 + %291 = and i1 %290, %289, !dbg !50 + %292 = or i1 %287, %291, !dbg !51 + %293 = select i1 %292, float %275, float %279, !dbg !52 + %294 = select i1 %292, i32 %276, i32 %280, !dbg !53 + %295 = bitcast float %293 to i32, !dbg !38 + %296 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %295, i32 2, i32 31), !dbg !38 + %297 = bitcast i32 %296 to float, !dbg !38 + %298 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %294, i32 2, i32 31), !dbg !38 + %299 = fcmp ogt float %293, %297, !dbg !40 + %300 = fcmp oeq float %293, %297, !dbg !41 + %301 = fcmp uno float %293, 0.000000e+00, !dbg !42 + %302 = fcmp uno float %297, 0.000000e+00, !dbg !43 + %303 = xor i1 %302, true, !dbg !44 + %304 = and i1 %301, %303, !dbg !46 + %305 = or i1 %299, %304, !dbg !47 + %306 = and i1 %302, %301, !dbg !45 + %307 = or i1 %300, 
%306, !dbg !48 + %308 = icmp slt i32 %294, %298, !dbg !49 + %309 = and i1 %308, %307, !dbg !50 + %310 = or i1 %305, %309, !dbg !51 + %311 = select i1 %310, float %293, float %297, !dbg !52 + %312 = select i1 %310, i32 %294, i32 %298, !dbg !53 + %313 = bitcast float %311 to i32, !dbg !38 + %314 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %313, i32 1, i32 31), !dbg !38 + %315 = bitcast i32 %314 to float, !dbg !38 + %316 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %312, i32 1, i32 31), !dbg !38 + %317 = fcmp ogt float %311, %315, !dbg !40 + %318 = fcmp oeq float %311, %315, !dbg !41 + %319 = fcmp uno float %311, 0.000000e+00, !dbg !42 + %320 = fcmp uno float %315, 0.000000e+00, !dbg !43 + %321 = xor i1 %320, true, !dbg !44 + %322 = and i1 %319, %321, !dbg !46 + %323 = or i1 %317, %322, !dbg !47 + %324 = and i1 %320, %319, !dbg !45 + %325 = or i1 %318, %324, !dbg !48 + %326 = icmp slt i32 %312, %316, !dbg !49 + %327 = and i1 %326, %325, !dbg !50 + %328 = or i1 %323, %327, !dbg !51 + %329 = select i1 %328, i32 %312, i32 %316, !dbg !53 + %330 = extractelement <8 x float> %139, i64 5, !dbg !38 + %331 = bitcast float %330 to i32, !dbg !38 + %332 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %331, i32 16, i32 31), !dbg !38 + %333 = bitcast i32 %332 to float, !dbg !38 + %334 = extractelement <8 x i32> %140, i64 5, !dbg !38 + %335 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %334, i32 16, i32 31), !dbg !38 + %336 = fcmp ogt float %330, %333, !dbg !40 + %337 = fcmp oeq float %330, %333, !dbg !41 + %338 = fcmp uno float %333, 0.000000e+00, !dbg !43 + %339 = xor i1 %338, true, !dbg !44 + %340 = extractelement <8 x i1> %155, i64 5, !dbg !45 + %341 = and i1 %340, %339, !dbg !46 + %342 = or i1 %336, %341, !dbg !47 + %343 = and i1 %340, %338, !dbg !45 + %344 = or i1 %337, %343, !dbg !48 + %345 = icmp slt i32 %334, %335, !dbg !49 + %346 = and i1 %345, %344, !dbg !50 + %347 = or i1 %342, %346, !dbg !51 + %348 = select i1 %347, float 
%330, float %333, !dbg !52 + %349 = select i1 %347, i32 %334, i32 %335, !dbg !53 + %350 = bitcast float %348 to i32, !dbg !38 + %351 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %350, i32 8, i32 31), !dbg !38 + %352 = bitcast i32 %351 to float, !dbg !38 + %353 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %349, i32 8, i32 31), !dbg !38 + %354 = fcmp ogt float %348, %352, !dbg !40 + %355 = fcmp oeq float %348, %352, !dbg !41 + %356 = fcmp uno float %348, 0.000000e+00, !dbg !42 + %357 = fcmp uno float %352, 0.000000e+00, !dbg !43 + %358 = xor i1 %357, true, !dbg !44 + %359 = and i1 %356, %358, !dbg !46 + %360 = or i1 %354, %359, !dbg !47 + %361 = and i1 %357, %356, !dbg !45 + %362 = or i1 %355, %361, !dbg !48 + %363 = icmp slt i32 %349, %353, !dbg !49 + %364 = and i1 %363, %362, !dbg !50 + %365 = or i1 %360, %364, !dbg !51 + %366 = select i1 %365, float %348, float %352, !dbg !52 + %367 = select i1 %365, i32 %349, i32 %353, !dbg !53 + %368 = bitcast float %366 to i32, !dbg !38 + %369 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %368, i32 4, i32 31), !dbg !38 + %370 = bitcast i32 %369 to float, !dbg !38 + %371 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %367, i32 4, i32 31), !dbg !38 + %372 = fcmp ogt float %366, %370, !dbg !40 + %373 = fcmp oeq float %366, %370, !dbg !41 + %374 = fcmp uno float %366, 0.000000e+00, !dbg !42 + %375 = fcmp uno float %370, 0.000000e+00, !dbg !43 + %376 = xor i1 %375, true, !dbg !44 + %377 = and i1 %374, %376, !dbg !46 + %378 = or i1 %372, %377, !dbg !47 + %379 = and i1 %375, %374, !dbg !45 + %380 = or i1 %373, %379, !dbg !48 + %381 = icmp slt i32 %367, %371, !dbg !49 + %382 = and i1 %381, %380, !dbg !50 + %383 = or i1 %378, %382, !dbg !51 + %384 = select i1 %383, float %366, float %370, !dbg !52 + %385 = select i1 %383, i32 %367, i32 %371, !dbg !53 + %386 = bitcast float %384 to i32, !dbg !38 + %387 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %386, i32 2, i32 31), !dbg !38 
+ %388 = bitcast i32 %387 to float, !dbg !38 + %389 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %385, i32 2, i32 31), !dbg !38 + %390 = fcmp ogt float %384, %388, !dbg !40 + %391 = fcmp oeq float %384, %388, !dbg !41 + %392 = fcmp uno float %384, 0.000000e+00, !dbg !42 + %393 = fcmp uno float %388, 0.000000e+00, !dbg !43 + %394 = xor i1 %393, true, !dbg !44 + %395 = and i1 %392, %394, !dbg !46 + %396 = or i1 %390, %395, !dbg !47 + %397 = and i1 %393, %392, !dbg !45 + %398 = or i1 %391, %397, !dbg !48 + %399 = icmp slt i32 %385, %389, !dbg !49 + %400 = and i1 %399, %398, !dbg !50 + %401 = or i1 %396, %400, !dbg !51 + %402 = select i1 %401, float %384, float %388, !dbg !52 + %403 = select i1 %401, i32 %385, i32 %389, !dbg !53 + %404 = bitcast float %402 to i32, !dbg !38 + %405 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %404, i32 1, i32 31), !dbg !38 + %406 = bitcast i32 %405 to float, !dbg !38 + %407 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %403, i32 1, i32 31), !dbg !38 + %408 = fcmp ogt float %402, %406, !dbg !40 + %409 = fcmp oeq float %402, %406, !dbg !41 + %410 = fcmp uno float %402, 0.000000e+00, !dbg !42 + %411 = fcmp uno float %406, 0.000000e+00, !dbg !43 + %412 = xor i1 %411, true, !dbg !44 + %413 = and i1 %410, %412, !dbg !46 + %414 = or i1 %408, %413, !dbg !47 + %415 = and i1 %411, %410, !dbg !45 + %416 = or i1 %409, %415, !dbg !48 + %417 = icmp slt i32 %403, %407, !dbg !49 + %418 = and i1 %417, %416, !dbg !50 + %419 = or i1 %414, %418, !dbg !51 + %420 = select i1 %419, i32 %403, i32 %407, !dbg !53 + %421 = extractelement <8 x float> %139, i64 4, !dbg !38 + %422 = bitcast float %421 to i32, !dbg !38 + %423 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %422, i32 16, i32 31), !dbg !38 + %424 = bitcast i32 %423 to float, !dbg !38 + %425 = extractelement <8 x i32> %140, i64 4, !dbg !38 + %426 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %425, i32 16, i32 31), !dbg !38 + %427 = fcmp ogt 
float %421, %424, !dbg !40 + %428 = fcmp oeq float %421, %424, !dbg !41 + %429 = fcmp uno float %424, 0.000000e+00, !dbg !43 + %430 = xor i1 %429, true, !dbg !44 + %431 = extractelement <8 x i1> %155, i64 4, !dbg !45 + %432 = and i1 %431, %430, !dbg !46 + %433 = or i1 %427, %432, !dbg !47 + %434 = and i1 %431, %429, !dbg !45 + %435 = or i1 %428, %434, !dbg !48 + %436 = icmp slt i32 %425, %426, !dbg !49 + %437 = and i1 %436, %435, !dbg !50 + %438 = or i1 %433, %437, !dbg !51 + %439 = select i1 %438, float %421, float %424, !dbg !52 + %440 = select i1 %438, i32 %425, i32 %426, !dbg !53 + %441 = bitcast float %439 to i32, !dbg !38 + %442 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %441, i32 8, i32 31), !dbg !38 + %443 = bitcast i32 %442 to float, !dbg !38 + %444 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %440, i32 8, i32 31), !dbg !38 + %445 = fcmp ogt float %439, %443, !dbg !40 + %446 = fcmp oeq float %439, %443, !dbg !41 + %447 = fcmp uno float %439, 0.000000e+00, !dbg !42 + %448 = fcmp uno float %443, 0.000000e+00, !dbg !43 + %449 = xor i1 %448, true, !dbg !44 + %450 = and i1 %447, %449, !dbg !46 + %451 = or i1 %445, %450, !dbg !47 + %452 = and i1 %448, %447, !dbg !45 + %453 = or i1 %446, %452, !dbg !48 + %454 = icmp slt i32 %440, %444, !dbg !49 + %455 = and i1 %454, %453, !dbg !50 + %456 = or i1 %451, %455, !dbg !51 + %457 = select i1 %456, float %439, float %443, !dbg !52 + %458 = select i1 %456, i32 %440, i32 %444, !dbg !53 + %459 = bitcast float %457 to i32, !dbg !38 + %460 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %459, i32 4, i32 31), !dbg !38 + %461 = bitcast i32 %460 to float, !dbg !38 + %462 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %458, i32 4, i32 31), !dbg !38 + %463 = fcmp ogt float %457, %461, !dbg !40 + %464 = fcmp oeq float %457, %461, !dbg !41 + %465 = fcmp uno float %457, 0.000000e+00, !dbg !42 + %466 = fcmp uno float %461, 0.000000e+00, !dbg !43 + %467 = xor i1 %466, true, !dbg !44 + 
%468 = and i1 %465, %467, !dbg !46 + %469 = or i1 %463, %468, !dbg !47 + %470 = and i1 %466, %465, !dbg !45 + %471 = or i1 %464, %470, !dbg !48 + %472 = icmp slt i32 %458, %462, !dbg !49 + %473 = and i1 %472, %471, !dbg !50 + %474 = or i1 %469, %473, !dbg !51 + %475 = select i1 %474, float %457, float %461, !dbg !52 + %476 = select i1 %474, i32 %458, i32 %462, !dbg !53 + %477 = bitcast float %475 to i32, !dbg !38 + %478 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %477, i32 2, i32 31), !dbg !38 + %479 = bitcast i32 %478 to float, !dbg !38 + %480 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %476, i32 2, i32 31), !dbg !38 + %481 = fcmp ogt float %475, %479, !dbg !40 + %482 = fcmp oeq float %475, %479, !dbg !41 + %483 = fcmp uno float %475, 0.000000e+00, !dbg !42 + %484 = fcmp uno float %479, 0.000000e+00, !dbg !43 + %485 = xor i1 %484, true, !dbg !44 + %486 = and i1 %483, %485, !dbg !46 + %487 = or i1 %481, %486, !dbg !47 + %488 = and i1 %484, %483, !dbg !45 + %489 = or i1 %482, %488, !dbg !48 + %490 = icmp slt i32 %476, %480, !dbg !49 + %491 = and i1 %490, %489, !dbg !50 + %492 = or i1 %487, %491, !dbg !51 + %493 = select i1 %492, float %475, float %479, !dbg !52 + %494 = select i1 %492, i32 %476, i32 %480, !dbg !53 + %495 = bitcast float %493 to i32, !dbg !38 + %496 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %495, i32 1, i32 31), !dbg !38 + %497 = bitcast i32 %496 to float, !dbg !38 + %498 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %494, i32 1, i32 31), !dbg !38 + %499 = fcmp ogt float %493, %497, !dbg !40 + %500 = fcmp oeq float %493, %497, !dbg !41 + %501 = fcmp uno float %493, 0.000000e+00, !dbg !42 + %502 = fcmp uno float %497, 0.000000e+00, !dbg !43 + %503 = xor i1 %502, true, !dbg !44 + %504 = and i1 %501, %503, !dbg !46 + %505 = or i1 %499, %504, !dbg !47 + %506 = and i1 %502, %501, !dbg !45 + %507 = or i1 %500, %506, !dbg !48 + %508 = icmp slt i32 %494, %498, !dbg !49 + %509 = and i1 %508, %507, !dbg 
!50 + %510 = or i1 %505, %509, !dbg !51 + %511 = select i1 %510, i32 %494, i32 %498, !dbg !53 + %512 = extractelement <8 x float> %139, i64 3, !dbg !38 + %513 = bitcast float %512 to i32, !dbg !38 + %514 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %513, i32 16, i32 31), !dbg !38 + %515 = bitcast i32 %514 to float, !dbg !38 + %516 = extractelement <8 x i32> %140, i64 3, !dbg !38 + %517 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %516, i32 16, i32 31), !dbg !38 + %518 = fcmp ogt float %512, %515, !dbg !40 + %519 = fcmp oeq float %512, %515, !dbg !41 + %520 = fcmp uno float %515, 0.000000e+00, !dbg !43 + %521 = xor i1 %520, true, !dbg !44 + %522 = extractelement <8 x i1> %155, i64 3, !dbg !45 + %523 = and i1 %522, %521, !dbg !46 + %524 = or i1 %518, %523, !dbg !47 + %525 = and i1 %522, %520, !dbg !45 + %526 = or i1 %519, %525, !dbg !48 + %527 = icmp slt i32 %516, %517, !dbg !49 + %528 = and i1 %527, %526, !dbg !50 + %529 = or i1 %524, %528, !dbg !51 + %530 = select i1 %529, float %512, float %515, !dbg !52 + %531 = select i1 %529, i32 %516, i32 %517, !dbg !53 + %532 = bitcast float %530 to i32, !dbg !38 + %533 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %532, i32 8, i32 31), !dbg !38 + %534 = bitcast i32 %533 to float, !dbg !38 + %535 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %531, i32 8, i32 31), !dbg !38 + %536 = fcmp ogt float %530, %534, !dbg !40 + %537 = fcmp oeq float %530, %534, !dbg !41 + %538 = fcmp uno float %530, 0.000000e+00, !dbg !42 + %539 = fcmp uno float %534, 0.000000e+00, !dbg !43 + %540 = xor i1 %539, true, !dbg !44 + %541 = and i1 %538, %540, !dbg !46 + %542 = or i1 %536, %541, !dbg !47 + %543 = and i1 %539, %538, !dbg !45 + %544 = or i1 %537, %543, !dbg !48 + %545 = icmp slt i32 %531, %535, !dbg !49 + %546 = and i1 %545, %544, !dbg !50 + %547 = or i1 %542, %546, !dbg !51 + %548 = select i1 %547, float %530, float %534, !dbg !52 + %549 = select i1 %547, i32 %531, i32 %535, !dbg !53 + %550 = 
bitcast float %548 to i32, !dbg !38 + %551 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %550, i32 4, i32 31), !dbg !38 + %552 = bitcast i32 %551 to float, !dbg !38 + %553 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %549, i32 4, i32 31), !dbg !38 + %554 = fcmp ogt float %548, %552, !dbg !40 + %555 = fcmp oeq float %548, %552, !dbg !41 + %556 = fcmp uno float %548, 0.000000e+00, !dbg !42 + %557 = fcmp uno float %552, 0.000000e+00, !dbg !43 + %558 = xor i1 %557, true, !dbg !44 + %559 = and i1 %556, %558, !dbg !46 + %560 = or i1 %554, %559, !dbg !47 + %561 = and i1 %557, %556, !dbg !45 + %562 = or i1 %555, %561, !dbg !48 + %563 = icmp slt i32 %549, %553, !dbg !49 + %564 = and i1 %563, %562, !dbg !50 + %565 = or i1 %560, %564, !dbg !51 + %566 = select i1 %565, float %548, float %552, !dbg !52 + %567 = select i1 %565, i32 %549, i32 %553, !dbg !53 + %568 = bitcast float %566 to i32, !dbg !38 + %569 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %568, i32 2, i32 31), !dbg !38 + %570 = bitcast i32 %569 to float, !dbg !38 + %571 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %567, i32 2, i32 31), !dbg !38 + %572 = fcmp ogt float %566, %570, !dbg !40 + %573 = fcmp oeq float %566, %570, !dbg !41 + %574 = fcmp uno float %566, 0.000000e+00, !dbg !42 + %575 = fcmp uno float %570, 0.000000e+00, !dbg !43 + %576 = xor i1 %575, true, !dbg !44 + %577 = and i1 %574, %576, !dbg !46 + %578 = or i1 %572, %577, !dbg !47 + %579 = and i1 %575, %574, !dbg !45 + %580 = or i1 %573, %579, !dbg !48 + %581 = icmp slt i32 %567, %571, !dbg !49 + %582 = and i1 %581, %580, !dbg !50 + %583 = or i1 %578, %582, !dbg !51 + %584 = select i1 %583, float %566, float %570, !dbg !52 + %585 = select i1 %583, i32 %567, i32 %571, !dbg !53 + %586 = bitcast float %584 to i32, !dbg !38 + %587 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %586, i32 1, i32 31), !dbg !38 + %588 = bitcast i32 %587 to float, !dbg !38 + %589 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %585, i32 1, i32 31), !dbg !38 + %590 = fcmp ogt float %584, %588, !dbg !40 + %591 = fcmp oeq float %584, %588, !dbg !41 + %592 = fcmp uno float %584, 0.000000e+00, !dbg !42 + %593 = fcmp uno float %588, 0.000000e+00, !dbg !43 + %594 = xor i1 %593, true, !dbg !44 + %595 = and i1 %592, %594, !dbg !46 + %596 = or i1 %590, %595, !dbg !47 + %597 = and i1 %593, %592, !dbg !45 + %598 = or i1 %591, %597, !dbg !48 + %599 = icmp slt i32 %585, %589, !dbg !49 + %600 = and i1 %599, %598, !dbg !50 + %601 = or i1 %596, %600, !dbg !51 + %602 = select i1 %601, i32 %585, i32 %589, !dbg !53 + %603 = extractelement <8 x float> %139, i64 2, !dbg !38 + %604 = bitcast float %603 to i32, !dbg !38 + %605 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %604, i32 16, i32 31), !dbg !38 + %606 = bitcast i32 %605 to float, !dbg !38 + %607 = extractelement <8 x i32> %140, i64 2, !dbg !38 + %608 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %607, i32 16, i32 31), !dbg !38 + %609 = fcmp ogt float %603, %606, !dbg !40 + %610 = fcmp oeq float %603, %606, !dbg !41 + %611 = fcmp uno float %606, 0.000000e+00, !dbg !43 + %612 = xor i1 %611, true, !dbg !44 + %613 = extractelement <8 x i1> %155, i64 2, !dbg !45 + %614 = and i1 %613, %612, !dbg !46 + %615 = or i1 %609, %614, !dbg !47 + %616 = and i1 %613, %611, !dbg !45 + %617 = or i1 %610, %616, !dbg !48 + %618 = icmp slt i32 %607, %608, !dbg !49 + %619 = and i1 %618, %617, !dbg !50 + %620 = or i1 %615, %619, !dbg !51 + %621 = select i1 %620, float %603, float %606, !dbg !52 + %622 = select i1 %620, i32 %607, i32 %608, !dbg !53 + %623 = bitcast float %621 to i32, !dbg !38 + %624 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %623, i32 8, i32 31), !dbg !38 + %625 = bitcast i32 %624 to float, !dbg !38 + %626 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %622, i32 8, i32 31), !dbg !38 + %627 = fcmp ogt float %621, %625, !dbg !40 + %628 = fcmp oeq float %621, %625, !dbg 
!41 + %629 = fcmp uno float %621, 0.000000e+00, !dbg !42 + %630 = fcmp uno float %625, 0.000000e+00, !dbg !43 + %631 = xor i1 %630, true, !dbg !44 + %632 = and i1 %629, %631, !dbg !46 + %633 = or i1 %627, %632, !dbg !47 + %634 = and i1 %630, %629, !dbg !45 + %635 = or i1 %628, %634, !dbg !48 + %636 = icmp slt i32 %622, %626, !dbg !49 + %637 = and i1 %636, %635, !dbg !50 + %638 = or i1 %633, %637, !dbg !51 + %639 = select i1 %638, float %621, float %625, !dbg !52 + %640 = select i1 %638, i32 %622, i32 %626, !dbg !53 + %641 = bitcast float %639 to i32, !dbg !38 + %642 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %641, i32 4, i32 31), !dbg !38 + %643 = bitcast i32 %642 to float, !dbg !38 + %644 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %640, i32 4, i32 31), !dbg !38 + %645 = fcmp ogt float %639, %643, !dbg !40 + %646 = fcmp oeq float %639, %643, !dbg !41 + %647 = fcmp uno float %639, 0.000000e+00, !dbg !42 + %648 = fcmp uno float %643, 0.000000e+00, !dbg !43 + %649 = xor i1 %648, true, !dbg !44 + %650 = and i1 %647, %649, !dbg !46 + %651 = or i1 %645, %650, !dbg !47 + %652 = and i1 %648, %647, !dbg !45 + %653 = or i1 %646, %652, !dbg !48 + %654 = icmp slt i32 %640, %644, !dbg !49 + %655 = and i1 %654, %653, !dbg !50 + %656 = or i1 %651, %655, !dbg !51 + %657 = select i1 %656, float %639, float %643, !dbg !52 + %658 = select i1 %656, i32 %640, i32 %644, !dbg !53 + %659 = bitcast float %657 to i32, !dbg !38 + %660 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %659, i32 2, i32 31), !dbg !38 + %661 = bitcast i32 %660 to float, !dbg !38 + %662 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %658, i32 2, i32 31), !dbg !38 + %663 = fcmp ogt float %657, %661, !dbg !40 + %664 = fcmp oeq float %657, %661, !dbg !41 + %665 = fcmp uno float %657, 0.000000e+00, !dbg !42 + %666 = fcmp uno float %661, 0.000000e+00, !dbg !43 + %667 = xor i1 %666, true, !dbg !44 + %668 = and i1 %665, %667, !dbg !46 + %669 = or i1 %663, %668, !dbg !47 + 
%670 = and i1 %666, %665, !dbg !45 + %671 = or i1 %664, %670, !dbg !48 + %672 = icmp slt i32 %658, %662, !dbg !49 + %673 = and i1 %672, %671, !dbg !50 + %674 = or i1 %669, %673, !dbg !51 + %675 = select i1 %674, float %657, float %661, !dbg !52 + %676 = select i1 %674, i32 %658, i32 %662, !dbg !53 + %677 = bitcast float %675 to i32, !dbg !38 + %678 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %677, i32 1, i32 31), !dbg !38 + %679 = bitcast i32 %678 to float, !dbg !38 + %680 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %676, i32 1, i32 31), !dbg !38 + %681 = fcmp ogt float %675, %679, !dbg !40 + %682 = fcmp oeq float %675, %679, !dbg !41 + %683 = fcmp uno float %675, 0.000000e+00, !dbg !42 + %684 = fcmp uno float %679, 0.000000e+00, !dbg !43 + %685 = xor i1 %684, true, !dbg !44 + %686 = and i1 %683, %685, !dbg !46 + %687 = or i1 %681, %686, !dbg !47 + %688 = and i1 %684, %683, !dbg !45 + %689 = or i1 %682, %688, !dbg !48 + %690 = icmp slt i32 %676, %680, !dbg !49 + %691 = and i1 %690, %689, !dbg !50 + %692 = or i1 %687, %691, !dbg !51 + %693 = select i1 %692, i32 %676, i32 %680, !dbg !53 + %694 = extractelement <8 x float> %139, i64 1, !dbg !38 + %695 = bitcast float %694 to i32, !dbg !38 + %696 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %695, i32 16, i32 31), !dbg !38 + %697 = bitcast i32 %696 to float, !dbg !38 + %698 = extractelement <8 x i32> %140, i64 1, !dbg !38 + %699 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %698, i32 16, i32 31), !dbg !38 + %700 = fcmp ogt float %694, %697, !dbg !40 + %701 = fcmp oeq float %694, %697, !dbg !41 + %702 = fcmp uno float %697, 0.000000e+00, !dbg !43 + %703 = xor i1 %702, true, !dbg !44 + %704 = extractelement <8 x i1> %155, i64 1, !dbg !45 + %705 = and i1 %704, %703, !dbg !46 + %706 = or i1 %700, %705, !dbg !47 + %707 = and i1 %704, %702, !dbg !45 + %708 = or i1 %701, %707, !dbg !48 + %709 = icmp slt i32 %698, %699, !dbg !49 + %710 = and i1 %709, %708, !dbg !50 + %711 = 
or i1 %706, %710, !dbg !51 + %712 = select i1 %711, float %694, float %697, !dbg !52 + %713 = select i1 %711, i32 %698, i32 %699, !dbg !53 + %714 = bitcast float %712 to i32, !dbg !38 + %715 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %714, i32 8, i32 31), !dbg !38 + %716 = bitcast i32 %715 to float, !dbg !38 + %717 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %713, i32 8, i32 31), !dbg !38 + %718 = fcmp ogt float %712, %716, !dbg !40 + %719 = fcmp oeq float %712, %716, !dbg !41 + %720 = fcmp uno float %712, 0.000000e+00, !dbg !42 + %721 = fcmp uno float %716, 0.000000e+00, !dbg !43 + %722 = xor i1 %721, true, !dbg !44 + %723 = and i1 %720, %722, !dbg !46 + %724 = or i1 %718, %723, !dbg !47 + %725 = and i1 %721, %720, !dbg !45 + %726 = or i1 %719, %725, !dbg !48 + %727 = icmp slt i32 %713, %717, !dbg !49 + %728 = and i1 %727, %726, !dbg !50 + %729 = or i1 %724, %728, !dbg !51 + %730 = select i1 %729, float %712, float %716, !dbg !52 + %731 = select i1 %729, i32 %713, i32 %717, !dbg !53 + %732 = bitcast float %730 to i32, !dbg !38 + %733 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %732, i32 4, i32 31), !dbg !38 + %734 = bitcast i32 %733 to float, !dbg !38 + %735 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %731, i32 4, i32 31), !dbg !38 + %736 = fcmp ogt float %730, %734, !dbg !40 + %737 = fcmp oeq float %730, %734, !dbg !41 + %738 = fcmp uno float %730, 0.000000e+00, !dbg !42 + %739 = fcmp uno float %734, 0.000000e+00, !dbg !43 + %740 = xor i1 %739, true, !dbg !44 + %741 = and i1 %738, %740, !dbg !46 + %742 = or i1 %736, %741, !dbg !47 + %743 = and i1 %739, %738, !dbg !45 + %744 = or i1 %737, %743, !dbg !48 + %745 = icmp slt i32 %731, %735, !dbg !49 + %746 = and i1 %745, %744, !dbg !50 + %747 = or i1 %742, %746, !dbg !51 + %748 = select i1 %747, float %730, float %734, !dbg !52 + %749 = select i1 %747, i32 %731, i32 %735, !dbg !53 + %750 = bitcast float %748 to i32, !dbg !38 + %751 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %750, i32 2, i32 31), !dbg !38 + %752 = bitcast i32 %751 to float, !dbg !38 + %753 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %749, i32 2, i32 31), !dbg !38 + %754 = fcmp ogt float %748, %752, !dbg !40 + %755 = fcmp oeq float %748, %752, !dbg !41 + %756 = fcmp uno float %748, 0.000000e+00, !dbg !42 + %757 = fcmp uno float %752, 0.000000e+00, !dbg !43 + %758 = xor i1 %757, true, !dbg !44 + %759 = and i1 %756, %758, !dbg !46 + %760 = or i1 %754, %759, !dbg !47 + %761 = and i1 %757, %756, !dbg !45 + %762 = or i1 %755, %761, !dbg !48 + %763 = icmp slt i32 %749, %753, !dbg !49 + %764 = and i1 %763, %762, !dbg !50 + %765 = or i1 %760, %764, !dbg !51 + %766 = select i1 %765, float %748, float %752, !dbg !52 + %767 = select i1 %765, i32 %749, i32 %753, !dbg !53 + %768 = bitcast float %766 to i32, !dbg !38 + %769 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %768, i32 1, i32 31), !dbg !38 + %770 = bitcast i32 %769 to float, !dbg !38 + %771 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %767, i32 1, i32 31), !dbg !38 + %772 = fcmp ogt float %766, %770, !dbg !40 + %773 = fcmp oeq float %766, %770, !dbg !41 + %774 = fcmp uno float %766, 0.000000e+00, !dbg !42 + %775 = fcmp uno float %770, 0.000000e+00, !dbg !43 + %776 = xor i1 %775, true, !dbg !44 + %777 = and i1 %774, %776, !dbg !46 + %778 = or i1 %772, %777, !dbg !47 + %779 = and i1 %775, %774, !dbg !45 + %780 = or i1 %773, %779, !dbg !48 + %781 = icmp slt i32 %767, %771, !dbg !49 + %782 = and i1 %781, %780, !dbg !50 + %783 = or i1 %778, %782, !dbg !51 + %784 = select i1 %783, i32 %767, i32 %771, !dbg !53 + %785 = extractelement <8 x float> %139, i64 0, !dbg !38 + %786 = bitcast float %785 to i32, !dbg !38 + %787 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %786, i32 16, i32 31), !dbg !38 + %788 = bitcast i32 %787 to float, !dbg !38 + %789 = extractelement <8 x i32> %140, i64 0, !dbg !38 + %790 = tail call i32 
@llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %789, i32 16, i32 31), !dbg !38 + %791 = fcmp ogt float %785, %788, !dbg !40 + %792 = fcmp oeq float %785, %788, !dbg !41 + %793 = fcmp uno float %788, 0.000000e+00, !dbg !43 + %794 = xor i1 %793, true, !dbg !44 + %795 = extractelement <8 x i1> %155, i64 0, !dbg !45 + %796 = and i1 %795, %794, !dbg !46 + %797 = or i1 %791, %796, !dbg !47 + %798 = and i1 %795, %793, !dbg !45 + %799 = or i1 %792, %798, !dbg !48 + %800 = icmp slt i32 %789, %790, !dbg !49 + %801 = and i1 %800, %799, !dbg !50 + %802 = or i1 %797, %801, !dbg !51 + %803 = select i1 %802, float %785, float %788, !dbg !52 + %804 = select i1 %802, i32 %789, i32 %790, !dbg !53 + %805 = bitcast float %803 to i32, !dbg !38 + %806 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %805, i32 8, i32 31), !dbg !38 + %807 = bitcast i32 %806 to float, !dbg !38 + %808 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %804, i32 8, i32 31), !dbg !38 + %809 = fcmp ogt float %803, %807, !dbg !40 + %810 = fcmp oeq float %803, %807, !dbg !41 + %811 = fcmp uno float %803, 0.000000e+00, !dbg !42 + %812 = fcmp uno float %807, 0.000000e+00, !dbg !43 + %813 = xor i1 %812, true, !dbg !44 + %814 = and i1 %811, %813, !dbg !46 + %815 = or i1 %809, %814, !dbg !47 + %816 = and i1 %812, %811, !dbg !45 + %817 = or i1 %810, %816, !dbg !48 + %818 = icmp slt i32 %804, %808, !dbg !49 + %819 = and i1 %818, %817, !dbg !50 + %820 = or i1 %815, %819, !dbg !51 + %821 = select i1 %820, float %803, float %807, !dbg !52 + %822 = select i1 %820, i32 %804, i32 %808, !dbg !53 + %823 = bitcast float %821 to i32, !dbg !38 + %824 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %823, i32 4, i32 31), !dbg !38 + %825 = bitcast i32 %824 to float, !dbg !38 + %826 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %822, i32 4, i32 31), !dbg !38 + %827 = fcmp ogt float %821, %825, !dbg !40 + %828 = fcmp oeq float %821, %825, !dbg !41 + %829 = fcmp uno float %821, 0.000000e+00, !dbg !42 + 
%830 = fcmp uno float %825, 0.000000e+00, !dbg !43 + %831 = xor i1 %830, true, !dbg !44 + %832 = and i1 %829, %831, !dbg !46 + %833 = or i1 %827, %832, !dbg !47 + %834 = and i1 %830, %829, !dbg !45 + %835 = or i1 %828, %834, !dbg !48 + %836 = icmp slt i32 %822, %826, !dbg !49 + %837 = and i1 %836, %835, !dbg !50 + %838 = or i1 %833, %837, !dbg !51 + %839 = select i1 %838, float %821, float %825, !dbg !52 + %840 = select i1 %838, i32 %822, i32 %826, !dbg !53 + %841 = bitcast float %839 to i32, !dbg !38 + %842 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %841, i32 2, i32 31), !dbg !38 + %843 = bitcast i32 %842 to float, !dbg !38 + %844 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %840, i32 2, i32 31), !dbg !38 + %845 = fcmp ogt float %839, %843, !dbg !40 + %846 = fcmp oeq float %839, %843, !dbg !41 + %847 = fcmp uno float %839, 0.000000e+00, !dbg !42 + %848 = fcmp uno float %843, 0.000000e+00, !dbg !43 + %849 = xor i1 %848, true, !dbg !44 + %850 = and i1 %847, %849, !dbg !46 + %851 = or i1 %845, %850, !dbg !47 + %852 = and i1 %848, %847, !dbg !45 + %853 = or i1 %846, %852, !dbg !48 + %854 = icmp slt i32 %840, %844, !dbg !49 + %855 = and i1 %854, %853, !dbg !50 + %856 = or i1 %851, %855, !dbg !51 + %857 = select i1 %856, float %839, float %843, !dbg !52 + %858 = select i1 %856, i32 %840, i32 %844, !dbg !53 + %859 = bitcast float %857 to i32, !dbg !38 + %860 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %859, i32 1, i32 31), !dbg !38 + %861 = bitcast i32 %860 to float, !dbg !38 + %862 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %858, i32 1, i32 31), !dbg !38 + %863 = fcmp ogt float %857, %861, !dbg !40 + %864 = fcmp oeq float %857, %861, !dbg !41 + %865 = fcmp uno float %857, 0.000000e+00, !dbg !42 + %866 = fcmp uno float %861, 0.000000e+00, !dbg !43 + %867 = xor i1 %866, true, !dbg !44 + %868 = and i1 %865, %867, !dbg !46 + %869 = or i1 %863, %868, !dbg !47 + %870 = and i1 %866, %865, !dbg !45 + %871 = or i1 %864, 
%870, !dbg !48 + %872 = icmp slt i32 %858, %862, !dbg !49 + %873 = and i1 %872, %871, !dbg !50 + %874 = or i1 %869, %873, !dbg !51 + %875 = select i1 %874, i32 %858, i32 %862, !dbg !53 + %876 = and i32 %146, 1, !dbg !38 + %877 = icmp eq i32 %145, 0, !dbg !38 + %878 = lshr exact i32 %12, 5, !dbg !38 + %879 = or disjoint i32 %878, %876, !dbg !38 + %880 = getelementptr float, ptr addrspace(3) @global_smem, i32 %879, !dbg !38 + %881 = select i1 %237, i32 %222, i32 %223, !dbg !52 + %882 = insertelement <1 x i32> poison, i32 %881, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %880, <1 x i32> %882, i1 %877) #4, !dbg !38 + %883 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %879, !dbg !38 + %884 = insertelement <1 x i32> poison, i32 %238, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %883, <1 x i32> %884, i1 %877) #4, !dbg !38 + %885 = shl nuw nsw i32 %15, 1, !dbg !38 + %886 = or disjoint i32 %885, %876, !dbg !38 + %887 = getelementptr float, ptr addrspace(3) @global_smem, i32 %886, !dbg !38 + %888 = select i1 %328, i32 %313, i32 %314, !dbg !52 + %889 = insertelement <1 x i32> poison, i32 %888, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %887, <1 x i32> %889, i1 %877) #4, !dbg !38 + %890 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %886, !dbg !38 + %891 = insertelement <1 x i32> poison, i32 %329, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %890, <1 x i32> %891, i1 %877) #4, !dbg !38 + %892 = shl nuw nsw i32 %16, 1, !dbg !38 + %893 = or disjoint i32 %892, %876, !dbg !38 + %894 = getelementptr float, ptr addrspace(3) @global_smem, i32 %893, !dbg !38 + %895 = select i1 %419, i32 %404, i32 %405, !dbg !52 + %896 
= insertelement <1 x i32> poison, i32 %895, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %894, <1 x i32> %896, i1 %877) #4, !dbg !38 + %897 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %893, !dbg !38 + %898 = insertelement <1 x i32> poison, i32 %420, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %897, <1 x i32> %898, i1 %877) #4, !dbg !38 + %899 = shl nuw nsw i32 %17, 1, !dbg !38 + %900 = or disjoint i32 %899, %876, !dbg !38 + %901 = getelementptr float, ptr addrspace(3) @global_smem, i32 %900, !dbg !38 + %902 = select i1 %510, i32 %495, i32 %496, !dbg !52 + %903 = insertelement <1 x i32> poison, i32 %902, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %901, <1 x i32> %903, i1 %877) #4, !dbg !38 + %904 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %900, !dbg !38 + %905 = insertelement <1 x i32> poison, i32 %511, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %904, <1 x i32> %905, i1 %877) #4, !dbg !38 + %906 = extractelement <4 x i32> %20, i64 3, !dbg !38 + %907 = shl nuw nsw i32 %906, 1, !dbg !38 + %908 = or disjoint i32 %907, %876, !dbg !38 + %909 = getelementptr float, ptr addrspace(3) @global_smem, i32 %908, !dbg !38 + %910 = select i1 %601, i32 %586, i32 %587, !dbg !52 + %911 = insertelement <1 x i32> poison, i32 %910, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %909, <1 x i32> %911, i1 %877) #4, !dbg !38 + %912 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %908, !dbg !38 + %913 = insertelement <1 x i32> poison, i32 %602, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 
st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %912, <1 x i32> %913, i1 %877) #4, !dbg !38 + %914 = extractelement <4 x i32> %20, i64 2, !dbg !38 + %915 = shl nuw nsw i32 %914, 1, !dbg !38 + %916 = or disjoint i32 %915, %876, !dbg !38 + %917 = getelementptr float, ptr addrspace(3) @global_smem, i32 %916, !dbg !38 + %918 = select i1 %692, i32 %677, i32 %678, !dbg !52 + %919 = insertelement <1 x i32> poison, i32 %918, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %917, <1 x i32> %919, i1 %877) #4, !dbg !38 + %920 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %916, !dbg !38 + %921 = insertelement <1 x i32> poison, i32 %693, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %920, <1 x i32> %921, i1 %877) #4, !dbg !38 + %922 = extractelement <4 x i32> %20, i64 1, !dbg !38 + %923 = shl nuw nsw i32 %922, 1, !dbg !38 + %924 = or disjoint i32 %923, %876, !dbg !38 + %925 = getelementptr float, ptr addrspace(3) @global_smem, i32 %924, !dbg !38 + %926 = select i1 %783, i32 %768, i32 %769, !dbg !52 + %927 = insertelement <1 x i32> poison, i32 %926, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %925, <1 x i32> %927, i1 %877) #4, !dbg !38 + %928 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %924, !dbg !38 + %929 = insertelement <1 x i32> poison, i32 %784, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %928, <1 x i32> %929, i1 %877) #4, !dbg !38 + %930 = extractelement <4 x i32> %20, i64 0, !dbg !38 + %931 = shl nuw nsw i32 %930, 1, !dbg !38 + %932 = or disjoint i32 %931, %876, !dbg !38 + %933 = getelementptr float, ptr addrspace(3) @global_smem, i32 %932, !dbg !38 + %934 = select i1 %874, i32 %859, i32 %860, !dbg 
!52 + %935 = insertelement <1 x i32> poison, i32 %934, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %933, <1 x i32> %935, i1 %877) #4, !dbg !38 + %936 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %932, !dbg !38 + %937 = insertelement <1 x i32> poison, i32 %875, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %936, <1 x i32> %937, i1 %877) #4, !dbg !38 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !38 + %938 = icmp samesign ult i32 %11, 128, !dbg !38 + %939 = getelementptr float, ptr addrspace(3) @global_smem, i32 %11, !dbg !38 + %940 = tail call i32 asm sideeffect "@$2 ld.shared.b32 $0, [ $1 + 0 ];", "=r,r,b"(ptr addrspace(3) %939, i1 %938) #4, !dbg !38 + %941 = bitcast i32 %940 to float, !dbg !38 + %942 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %11, !dbg !38 + %943 = tail call i32 asm sideeffect "@$2 ld.shared.b32 $0, [ $1 + 0 ];", "=r,r,b"(ptr addrspace(3) %942, i1 %938) #4, !dbg !38 + %944 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %940, i32 1, i32 31), !dbg !38 + %945 = bitcast i32 %944 to float, !dbg !38 + %946 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %943, i32 1, i32 31), !dbg !38 + %947 = fcmp ogt float %941, %945, !dbg !40 + %948 = fcmp oeq float %941, %945, !dbg !41 + %949 = fcmp uno float %941, 0.000000e+00, !dbg !42 + %950 = fcmp uno float %945, 0.000000e+00, !dbg !43 + %951 = xor i1 %950, true, !dbg !44 + %952 = and i1 %949, %951, !dbg !46 + %953 = or i1 %947, %952, !dbg !47 + %954 = and i1 %949, %950, !dbg !45 + %955 = or i1 %948, %954, !dbg !48 + %956 = icmp slt i32 %943, %946, !dbg !49 + %957 = and i1 %956, %955, !dbg !50 + %958 = or i1 %953, %957, !dbg !51 + %959 = select i1 %958, i32 %943, i32 %946, !dbg !53 + %960 = and i32 %11, 897, !dbg !38 + %961 = 
icmp eq i32 %960, 0, !dbg !38 + %962 = select i1 %958, i32 %940, i32 %944, !dbg !52 + %963 = insertelement <1 x i32> poison, i32 %962, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %939, <1 x i32> %963, i1 %961) #4, !dbg !38 + %964 = insertelement <1 x i32> poison, i32 %959, i64 0, !dbg !38 + tail call void asm sideeffect "@$2 st.shared.b32 [ $0 + 0 ], $1;", "r,r,b"(ptr addrspace(3) %942, <1 x i32> %964, i1 %961) #4, !dbg !38 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !38 + %965 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %878, !dbg !38 + %966 = load i32, ptr addrspace(3) %965, align 8, !dbg !38 + %967 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %885, !dbg !38 + %968 = load i32, ptr addrspace(3) %967, align 8, !dbg !38 + %969 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %892, !dbg !38 + %970 = load i32, ptr addrspace(3) %969, align 8, !dbg !38 + %971 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %899, !dbg !38 + %972 = load i32, ptr addrspace(3) %971, align 8, !dbg !38 + %973 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %907, !dbg !38 + %974 = load i32, ptr addrspace(3) %973, align 8, !dbg !38 + %975 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %915, !dbg !38 + %976 = load i32, ptr addrspace(3) %975, align 8, !dbg !38 + %977 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %923, !dbg !38 + %978 = load i32, ptr addrspace(3) %977, align 8, !dbg !38 + %979 = getelementptr i32, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 512), i32 %931, !dbg !38 + %980 = load i32, 
ptr addrspace(3) %979, align 8, !dbg !38 + %981 = sext i32 %143 to i64, !dbg !54 + %982 = getelementptr i64, ptr addrspace(1) %1, i64 %981, !dbg !54 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !55 + %983 = lshr exact i32 %12, 2, !dbg !55 + %984 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %983, !dbg !55 + %985 = insertelement <4 x i32> poison, i32 %966, i64 0, !dbg !55 + %986 = insertelement <4 x i32> %985, i32 %968, i64 1, !dbg !55 + %987 = insertelement <4 x i32> %986, i32 %970, i64 2, !dbg !55 + %988 = insertelement <4 x i32> %987, i32 %972, i64 3, !dbg !55 + store <4 x i32> %988, ptr addrspace(3) %984, align 16, !dbg !55 + %989 = getelementptr inbounds nuw i8, ptr addrspace(3) %984, i32 128, !dbg !55 + %990 = insertelement <4 x i32> poison, i32 %974, i64 0, !dbg !55 + %991 = insertelement <4 x i32> %990, i32 %976, i64 1, !dbg !55 + %992 = insertelement <4 x i32> %991, i32 %978, i64 2, !dbg !55 + %993 = insertelement <4 x i32> %992, i32 %980, i64 3, !dbg !55 + store <4 x i32> %993, ptr addrspace(3) %989, align 16, !dbg !55 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !55 + %994 = shl nuw nsw i32 %11, 4, !dbg !55 + %995 = and i32 %994, 112, !dbg !55 + %996 = lshr i32 %11, 1, !dbg !55 + %997 = and i32 %996, 12, !dbg !55 + %998 = shl nuw nsw i32 %11, 2, !dbg !55 + %999 = and i32 %998, 128, !dbg !55 + %1000 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %995, !dbg !55 + %1001 = getelementptr inbounds nuw i8, ptr addrspace(3) %1000, i32 %999, !dbg !55 + %1002 = getelementptr inbounds nuw i8, ptr addrspace(3) %1001, i32 %997, !dbg !55 + %1003 = load i32, ptr addrspace(3) %1002, align 4, !dbg !55 + %1004 = sext i32 %1003 to i64, !dbg !55 + %1005 = icmp eq i32 %12, 0, !dbg !55 + %1006 = and i1 %1005, %144, !dbg !55 + tail call void asm sideeffect "@$2 st.global.b64 [ $1 + 0 ], { $0 };", "l,l,b"(i64 %1004, ptr addrspace(1) %982, i1 %1006) #4, !dbg !55 + ret void, !dbg !56 +} + 
+; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="512" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py", directory: "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_argmax_1", linkageName: "triton_red_fused_argmax_1", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 22, column: 28, scope: !4) +!8 = !DILocation(line: 22, column: 33, scope: !4) +!9 = !DILocation(line: 23, column: 44, scope: !4) +!10 = !DILocation(line: 23, column: 23, scope: !4) +!11 = !DILocation(line: 24, column: 21, scope: !4) +!12 = !DILocation(line: 27, column: 19, 
scope: !4) +!13 = !DILocation(line: 28, column: 19, scope: !4) +!14 = !DILocation(line: 38, column: 56, scope: !4) +!15 = !DILocation(line: 32, column: 40, scope: !4) +!16 = !DILocation(line: 38, column: 61, scope: !4) +!17 = !DILocation(line: 33, column: 31, scope: !4) +!18 = !DILocation(line: 38, column: 34, scope: !4) +!19 = !DILocation(line: 147, column: 29, scope: !20, inlinedAt: !22) +!20 = distinct !DILexicalBlockFile(scope: !4, file: !21, discriminator: 0) +!21 = !DIFile(filename: "triton_helpers.py", directory: "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime") +!22 = !DILocation(line: 41, column: 38, scope: !4) +!23 = !DILocation(line: 155, column: 69, scope: !20, inlinedAt: !22) +!24 = !DILocation(line: 144, column: 21, scope: !20, inlinedAt: !22) +!25 = !DILocation(line: 145, column: 23, scope: !20, inlinedAt: !22) +!26 = !DILocation(line: 148, column: 29, scope: !20, inlinedAt: !22) +!27 = !DILocation(line: 149, column: 31, scope: !20, inlinedAt: !22) +!28 = !DILocation(line: 149, column: 27, scope: !20, inlinedAt: !22) +!29 = !DILocation(line: 149, column: 16, scope: !20, inlinedAt: !22) +!30 = !DILocation(line: 151, column: 27, scope: !20, inlinedAt: !22) +!31 = !DILocation(line: 151, column: 17, scope: !20, inlinedAt: !22) +!32 = !DILocation(line: 154, column: 31, scope: !20, inlinedAt: !22) +!33 = !DILocation(line: 154, column: 21, scope: !20, inlinedAt: !22) +!34 = !DILocation(line: 154, column: 12, scope: !20, inlinedAt: !22) +!35 = !DILocation(line: 155, column: 35, scope: !20, inlinedAt: !22) +!36 = !DILocation(line: 43, column: 54, scope: !4) +!37 = !DILocation(line: 44, column: 66, scope: !4) +!38 = !DILocation(line: 165, column: 42, scope: !20, inlinedAt: !39) +!39 = !DILocation(line: 45, column: 75, scope: !4) +!40 = !DILocation(line: 144, column: 21, scope: !20, inlinedAt: !39) +!41 = !DILocation(line: 145, column: 23, scope: !20, inlinedAt: !39) +!42 = !DILocation(line: 147, column: 29, scope: !20, inlinedAt: 
!39) +!43 = !DILocation(line: 148, column: 29, scope: !20, inlinedAt: !39) +!44 = !DILocation(line: 149, column: 31, scope: !20, inlinedAt: !39) +!45 = !DILocation(line: 151, column: 27, scope: !20, inlinedAt: !39) +!46 = !DILocation(line: 149, column: 27, scope: !20, inlinedAt: !39) +!47 = !DILocation(line: 149, column: 16, scope: !20, inlinedAt: !39) +!48 = !DILocation(line: 151, column: 17, scope: !20, inlinedAt: !39) +!49 = !DILocation(line: 154, column: 31, scope: !20, inlinedAt: !39) +!50 = !DILocation(line: 154, column: 21, scope: !20, inlinedAt: !39) +!51 = !DILocation(line: 154, column: 12, scope: !20, inlinedAt: !39) +!52 = !DILocation(line: 155, column: 35, scope: !20, inlinedAt: !39) +!53 = !DILocation(line: 155, column: 69, scope: !20, inlinedAt: !39) +!54 = !DILocation(line: 47, column: 25, scope: !4) +!55 = !DILocation(line: 47, column: 36, scope: !4) +!56 = !DILocation(line: 47, column: 4, scope: !4) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ptx b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ptx new file mode 100644 index 0000000000000000000000000000000000000000..3b7f078858280d4c83d8a33a4e489a673e3a2bde --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ptx @@ -0,0 +1,2198 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_argmax_1 // -- Begin function triton_red_fused_argmax_1 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_argmax_1 +.visible .entry triton_red_fused_argmax_1( + .param .u64 .ptr .global .align 1 triton_red_fused_argmax_1_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_argmax_1_param_1, + .param .u64 triton_red_fused_argmax_1_param_2, + .param .u64 
triton_red_fused_argmax_1_param_3, + .param .u32 triton_red_fused_argmax_1_param_4, + .param .u32 triton_red_fused_argmax_1_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused_argmax_1_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused_argmax_1_param_7 +) +.reqntid 512 +{ + .reg .pred %p<645>; + .reg .b32 %r<399>; + .reg .b64 %rd<214>; +$L__func_begin0: + +// %bb.0: + .loc 1 22 28 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:22:28 + mov.u32 %r45, %ctaid.x; + .loc 1 22 33 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:22:33 + shl.b32 %r1, %r45, 6; + ld.param.b64 %rd87, [triton_red_fused_argmax_1_param_2]; + .loc 1 23 44 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:23:44 + mov.u32 %r2, %tid.x; + bfe.u32 %r46, %r2, 6, 3; + or.b32 %r5, %r46, 8; + or.b32 %r6, %r46, 16; + or.b32 %r7, %r46, 24; + .loc 1 23 23 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:23:23 + or.b32 %r52, %r6, %r1; + or.b32 %r53, %r5, %r1; + or.b32 %r54, %r46, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd8, %r54; + cvt.s64.s32 %rd1, %r53; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd88, %rd8, %rd87; + and.b64 %rd89, %rd88, -4294967296; + setp.ne.b64 %p9, %rd89, 0; + @%p9 bra $L__BB0_2; + bra.uni $L__BB0_1; +$L__BB0_2: + div.s64 %rd193, %rd8, %rd87; + bra.uni $L__BB0_3; +$L__BB0_1: + cvt.u32.u64 %r55, %rd87; + cvt.u32.u64 %r56, %rd8; + div.u32 %r57, %r56, %r55; + cvt.u64.u32 %rd193, %r57; +$L__BB0_3: + .loc 1 0 0 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0 + or.b32 %r11, %r46, 32; + or.b32 %r51, %r7, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd2, %r52; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd91, %rd1, %rd87; + and.b64 %rd92, %rd91, -4294967296; + setp.ne.b64 %p10, %rd92, 0; + @%p10 bra $L__BB0_5; + 
bra.uni $L__BB0_4; +$L__BB0_5: + div.s64 %rd194, %rd1, %rd87; + bra.uni $L__BB0_6; +$L__BB0_4: + cvt.u32.u64 %r58, %rd87; + cvt.u32.u64 %r59, %rd1; + div.u32 %r60, %r59, %r58; + cvt.u64.u32 %rd194, %r60; +$L__BB0_6: + .loc 1 0 0 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0 + or.b32 %r10, %r46, 40; + or.b32 %r50, %r11, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd3, %r51; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd94, %rd2, %rd87; + and.b64 %rd95, %rd94, -4294967296; + setp.ne.b64 %p11, %rd95, 0; + @%p11 bra $L__BB0_8; + bra.uni $L__BB0_7; +$L__BB0_8: + div.s64 %rd195, %rd2, %rd87; + bra.uni $L__BB0_9; +$L__BB0_7: + cvt.u32.u64 %r61, %rd87; + cvt.u32.u64 %r62, %rd2; + div.u32 %r63, %r62, %r61; + cvt.u64.u32 %rd195, %r63; +$L__BB0_9: + .loc 1 0 0 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0 + or.b32 %r9, %r46, 48; + or.b32 %r49, %r10, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd4, %r50; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd97, %rd3, %rd87; + and.b64 %rd98, %rd97, -4294967296; + setp.ne.b64 %p12, %rd98, 0; + @%p12 bra $L__BB0_11; + bra.uni $L__BB0_10; +$L__BB0_11: + div.s64 %rd196, %rd3, %rd87; + bra.uni $L__BB0_12; +$L__BB0_10: + cvt.u32.u64 %r64, %rd87; + cvt.u32.u64 %r65, %rd3; + div.u32 %r66, %r65, %r64; + cvt.u64.u32 %rd196, %r66; +$L__BB0_12: + .loc 1 0 0 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0 + or.b32 %r8, %r46, 56; + or.b32 %r48, %r9, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd5, %r49; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd100, %rd4, %rd87; + and.b64 %rd101, %rd100, -4294967296; + setp.ne.b64 %p13, %rd101, 0; + @%p13 bra $L__BB0_14; + bra.uni $L__BB0_13; +$L__BB0_14: + div.s64 
%rd197, %rd4, %rd87; + bra.uni $L__BB0_15; +$L__BB0_13: + cvt.u32.u64 %r67, %rd87; + cvt.u32.u64 %r68, %rd4; + div.u32 %r69, %r68, %r67; + cvt.u64.u32 %rd197, %r69; +$L__BB0_15: + .loc 1 0 0 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0 + or.b32 %r47, %r8, %r1; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + cvt.s64.s32 %rd6, %r48; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd103, %rd5, %rd87; + and.b64 %rd104, %rd103, -4294967296; + setp.ne.b64 %p14, %rd104, 0; + @%p14 bra $L__BB0_17; + bra.uni $L__BB0_16; +$L__BB0_17: + div.s64 %rd198, %rd5, %rd87; + bra.uni $L__BB0_18; +$L__BB0_16: + cvt.u32.u64 %r70, %rd87; + cvt.u32.u64 %r71, %rd5; + div.u32 %r72, %r71, %r70; + cvt.u64.u32 %rd198, %r72; +$L__BB0_18: + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + mul.lo.s64 %rd90, %rd193, %rd87; + mul.lo.s64 %rd93, %rd194, %rd87; + mul.lo.s64 %rd96, %rd195, %rd87; + mul.lo.s64 %rd99, %rd196, %rd87; + mul.lo.s64 %rd102, %rd197, %rd87; + cvt.s64.s32 %rd7, %r47; + mul.lo.s64 %rd105, %rd198, %rd87; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd106, %rd6, %rd87; + and.b64 %rd107, %rd106, -4294967296; + setp.ne.b64 %p15, %rd107, 0; + @%p15 bra $L__BB0_20; + bra.uni $L__BB0_19; +$L__BB0_20: + div.s64 %rd199, %rd6, %rd87; + bra.uni $L__BB0_21; +$L__BB0_19: + cvt.u32.u64 %r73, %rd87; + cvt.u32.u64 %r74, %rd6; + div.u32 %r75, %r74, %r73; + cvt.u64.u32 %rd199, %r75; +$L__BB0_21: + .loc 1 0 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0:19 + ld.param.b32 %r44, [triton_red_fused_argmax_1_param_4]; + ld.param.b64 %rd86, [triton_red_fused_argmax_1_param_3]; + ld.param.b64 %rd84, [triton_red_fused_argmax_1_param_0]; + and.b32 %r4, %r2, 63; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + sub.s64 %rd13, %rd8, %rd90; + sub.s64 %rd18, %rd1, %rd93; + sub.s64 %rd23, %rd2, 
%rd96; + sub.s64 %rd28, %rd3, %rd99; + sub.s64 %rd33, %rd4, %rd102; + sub.s64 %rd38, %rd5, %rd105; + mul.lo.s64 %rd108, %rd199, %rd87; + sub.s64 %rd43, %rd6, %rd108; + .loc 1 28 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:28:19 + or.b64 %rd109, %rd7, %rd87; + and.b64 %rd110, %rd109, -4294967296; + setp.ne.b64 %p16, %rd110, 0; + @%p16 bra $L__BB0_23; + bra.uni $L__BB0_22; +$L__BB0_23: + div.s64 %rd200, %rd7, %rd87; + bra.uni $L__BB0_24; +$L__BB0_22: + cvt.u32.u64 %r76, %rd87; + cvt.u32.u64 %r77, %rd7; + div.u32 %r78, %r77, %r76; + cvt.u64.u32 %rd200, %r78; +$L__BB0_24: + .loc 1 0 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:0:19 + ld.param.b64 %rd85, [triton_red_fused_argmax_1_param_1]; + and.b32 %r3, %r2, 448; + .loc 1 24 21 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:24:21 + setp.lt.s32 %p24, %r47, %r44; + setp.lt.s32 %p23, %r48, %r44; + setp.lt.s32 %p22, %r49, %r44; + setp.lt.s32 %p21, %r50, %r44; + setp.lt.s32 %p20, %r51, %r44; + setp.lt.s32 %p19, %r52, %r44; + setp.lt.s32 %p18, %r53, %r44; + setp.lt.s32 %p17, %r54, %r44; + .loc 1 27 19 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:27:19 + mul.lo.s64 %rd116, %rd200, %rd87; + sub.s64 %rd117, %rd7, %rd116; + .loc 1 38 56 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:38:56 + mul.lo.s64 %rd118, %rd193, %rd86; + mul.lo.s64 %rd119, %rd194, %rd86; + mul.lo.s64 %rd120, %rd195, %rd86; + mul.lo.s64 %rd121, %rd196, %rd86; + mul.lo.s64 %rd122, %rd197, %rd86; + mul.lo.s64 %rd123, %rd198, %rd86; + mul.lo.s64 %rd124, %rd199, %rd86; + mul.lo.s64 %rd125, %rd200, %rd86; + mad.lo.s64 %rd126, %rd13, 128000, %rd84; + .loc 1 32 40 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:32:40 + shl.b64 %rd127, %rd118, 2; + add.s64 %rd201, %rd126, %rd127; + mad.lo.s64 %rd128, %rd18, 128000, %rd84; + shl.b64 %rd129, %rd119, 2; + add.s64 %rd202, %rd128, %rd129; + mad.lo.s64 %rd130, %rd23, 128000, %rd84; + shl.b64 %rd131, %rd120, 2; + add.s64 %rd203, %rd130, 
%rd131; + mad.lo.s64 %rd132, %rd28, 128000, %rd84; + shl.b64 %rd133, %rd121, 2; + add.s64 %rd204, %rd132, %rd133; + mad.lo.s64 %rd134, %rd33, 128000, %rd84; + shl.b64 %rd135, %rd122, 2; + add.s64 %rd205, %rd134, %rd135; + mad.lo.s64 %rd136, %rd38, 128000, %rd84; + shl.b64 %rd137, %rd123, 2; + add.s64 %rd206, %rd136, %rd137; + mad.lo.s64 %rd138, %rd43, 128000, %rd84; + shl.b64 %rd139, %rd124, 2; + add.s64 %rd207, %rd138, %rd139; + mad.lo.s64 %rd140, %rd117, 128000, %rd84; + shl.b64 %rd141, %rd125, 2; + add.s64 %rd208, %rd140, %rd141; + cvt.u64.u32 %rd56, %r4; + mov.b32 %r87, 0fFF800000; + mov.b64 %rd210, {%r87, %r87}; + mul.wide.u32 %rd57, %r4, 4; + mov.b32 %r391, 2147483647; + mov.b64 %rd209, 0; + mov.b64 %rd211, %rd210; + mov.b64 %rd212, %rd210; + mov.b64 %rd213, %rd210; + mov.b32 %r392, %r391; + mov.b32 %r393, %r391; + mov.b32 %r394, %r391; + mov.b32 %r395, %r391; + mov.b32 %r396, %r391; + mov.b32 %r397, %r391; + mov.b32 %r398, %r391; +$L__BB0_25: // =>This Inner Loop Header: Depth=1 + .loc 1 38 34 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:38:34 + add.s64 %rd166, %rd56, %rd209; + add.s64 %rd143, %rd201, %rd57; + add.s64 %rd146, %rd202, %rd57; + add.s64 %rd149, %rd203, %rd57; + add.s64 %rd152, %rd204, %rd57; + add.s64 %rd155, %rd205, %rd57; + add.s64 %rd158, %rd206, %rd57; + add.s64 %rd161, %rd207, %rd57; + .loc 1 38 61 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:38:61 + add.s64 %rd164, %rd208, %rd57; + // begin inline asm + mov.u64 %rd142, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd142, 1.0; + // end inline asm + mov.b32 %r89, 0; + // begin inline asm + mov.u32 %r88, %r89; + @%p17 ld.global.L1::evict_first.L2::cache_hint.b32 { %r88 }, [ %rd143 + 0 ], %rd142; + // end inline asm + // begin inline asm + mov.u64 %rd145, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd145, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r90, %r89; + @%p18 ld.global.L1::evict_first.L2::cache_hint.b32 { %r90 }, [ %rd146 + 0 ], 
%rd145; + // end inline asm + // begin inline asm + mov.u64 %rd148, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd148, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r92, %r89; + @%p19 ld.global.L1::evict_first.L2::cache_hint.b32 { %r92 }, [ %rd149 + 0 ], %rd148; + // end inline asm + // begin inline asm + mov.u64 %rd151, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd151, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r94, %r89; + @%p20 ld.global.L1::evict_first.L2::cache_hint.b32 { %r94 }, [ %rd152 + 0 ], %rd151; + // end inline asm + // begin inline asm + mov.u64 %rd154, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd154, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r96, %r89; + @%p21 ld.global.L1::evict_first.L2::cache_hint.b32 { %r96 }, [ %rd155 + 0 ], %rd154; + // end inline asm + // begin inline asm + mov.u64 %rd157, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd157, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r98, %r89; + @%p22 ld.global.L1::evict_first.L2::cache_hint.b32 { %r98 }, [ %rd158 + 0 ], %rd157; + // end inline asm + // begin inline asm + mov.u64 %rd160, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd160, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r100, %r89; + @%p23 ld.global.L1::evict_first.L2::cache_hint.b32 { %r100 }, [ %rd161 + 0 ], %rd160; + // end inline asm + // begin inline asm + mov.u64 %rd163, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd163, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r102, %r89; + @%p24 ld.global.L1::evict_first.L2::cache_hint.b32 { %r102 }, [ %rd164 + 0 ], %rd163; + // end inline asm +$L__tmp0: + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + mov.b64 {%r104, %r105}, %rd210; + setp.nan.f32 %p33, %r104, %r104; + setp.nan.f32 %p34, %r105, %r105; + mov.b64 {%r106, %r107}, %rd211; + setp.nan.f32 %p35, %r106, %r106; + setp.nan.f32 %p36, 
%r107, %r107; + mov.b64 {%r108, %r109}, %rd212; + setp.nan.f32 %p37, %r108, %r108; + setp.nan.f32 %p38, %r109, %r109; + mov.b64 {%r110, %r111}, %rd213; + setp.nan.f32 %p39, %r110, %r110; + setp.nan.f32 %p40, %r111, %r111; +$L__tmp1: + .loc 1 38 61 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:38:61 + cvt.u64.u32 %rd167, %r100; + shl.b64 %rd168, %rd167, 32; + cvt.u64.u32 %rd169, %r102; + or.b64 %rd170, %rd169, %rd168; + cvt.u64.u32 %rd171, %r96; + shl.b64 %rd172, %rd171, 32; + cvt.u64.u32 %rd173, %r98; + or.b64 %rd174, %rd173, %rd172; + cvt.u64.u32 %rd175, %r92; + shl.b64 %rd176, %rd175, 32; + cvt.u64.u32 %rd177, %r94; + or.b64 %rd178, %rd177, %rd176; + cvt.u64.u32 %rd179, %r88; + shl.b64 %rd180, %rd179, 32; + cvt.u64.u32 %rd181, %r90; + or.b64 %rd182, %rd181, %rd180; +$L__tmp2: + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + mov.b64 {%r112, %r113}, %rd182; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.gt.f32 %p41, %r111, %r113; + setp.gt.f32 %p42, %r110, %r112; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + mov.b64 {%r114, %r115}, %rd178; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.gt.f32 %p43, %r109, %r115; + setp.gt.f32 %p44, %r108, %r114; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + mov.b64 {%r116, %r117}, %rd174; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.gt.f32 %p45, %r107, %r117; + setp.gt.f32 %p46, %r106, %r116; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + mov.b64 {%r118, %r119}, %rd170; + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.gt.f32 %p47, %r105, %r119; + setp.gt.f32 %p48, %r104, %r118; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.eq.f32 %p49, %r104, %r118; + setp.eq.f32 %p50, %r105, %r119; + setp.eq.f32 %p51, %r106, %r116; + setp.eq.f32 %p52, %r107, %r117; + setp.eq.f32 %p53, %r108, %r114; + setp.eq.f32 %p54, %r109, %r115; + setp.eq.f32 %p55, %r110, %r112; + setp.eq.f32 %p56, %r111, %r113; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + setp.nan.f32 %p57, %r113, %r113; + setp.nan.f32 %p58, %r112, %r112; + setp.nan.f32 %p59, %r115, %r115; + setp.nan.f32 %p60, %r114, %r114; + setp.nan.f32 %p61, %r117, %r117; + setp.nan.f32 %p62, %r116, %r116; + setp.nan.f32 %p63, %r119, %r119; + setp.nan.f32 %p64, %r118, %r118; + setp.num.f32 %p65, %r118, %r118; + setp.num.f32 %p66, %r119, %r119; + setp.num.f32 %p67, %r116, %r116; + setp.num.f32 %p68, %r117, %r117; + setp.num.f32 %p69, %r114, %r114; + setp.num.f32 %p70, %r115, %r115; + setp.num.f32 %p71, %r112, %r112; + setp.num.f32 %p72, %r113, %r113; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + and.pred %p73, %p40, %p72; + and.pred %p74, %p39, %p71; + and.pred %p75, %p38, %p70; + and.pred %p76, %p37, %p69; + and.pred %p77, %p36, %p68; + and.pred %p78, %p35, %p67; + and.pred %p79, %p34, %p66; + and.pred %p80, %p33, %p65; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + or.pred %p81, %p48, %p80; + or.pred %p82, %p47, %p79; + or.pred %p83, %p46, %p78; + or.pred %p84, %p45, %p77; + or.pred %p85, %p44, %p76; + or.pred %p86, %p43, %p75; + or.pred %p87, %p42, %p74; + or.pred %p88, %p41, %p73; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + and.pred %p89, %p33, 
%p64; + and.pred %p90, %p34, %p63; + and.pred %p91, %p35, %p62; + and.pred %p92, %p36, %p61; + and.pred %p93, %p37, %p60; + and.pred %p94, %p38, %p59; + and.pred %p95, %p39, %p58; + and.pred %p96, %p40, %p57; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + or.pred %p97, %p56, %p96; + or.pred %p98, %p55, %p95; + or.pred %p99, %p54, %p94; + or.pred %p100, %p53, %p93; + or.pred %p101, %p52, %p92; + or.pred %p102, %p51, %p91; + or.pred %p103, %p50, %p90; + or.pred %p104, %p49, %p89; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + cvt.s64.s32 %rd183, %r391; + cvt.s64.s32 %rd184, %r392; + cvt.s64.s32 %rd185, %r393; + cvt.s64.s32 %rd186, %r394; + cvt.s64.s32 %rd187, %r395; + cvt.s64.s32 %rd188, %r396; + cvt.s64.s32 %rd189, %r397; + cvt.s64.s32 %rd190, %r398; + setp.gt.s64 %p105, %rd166, %rd190; + setp.gt.s64 %p106, %rd166, %rd189; + setp.gt.s64 %p107, %rd166, %rd188; + setp.gt.s64 %p108, %rd166, %rd187; + setp.gt.s64 %p109, %rd166, %rd186; + setp.gt.s64 %p110, %rd166, %rd185; + setp.gt.s64 %p111, %rd166, %rd184; + setp.gt.s64 %p112, %rd166, %rd183; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + and.pred %p113, %p112, %p104; + and.pred %p114, %p111, %p103; + and.pred %p115, %p110, %p102; + and.pred %p116, %p109, %p101; + and.pred %p117, %p108, %p100; + and.pred %p118, %p107, %p99; + and.pred %p119, %p106, %p98; + and.pred %p120, %p105, %p97; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + or.pred %p121, %p88, %p120; + or.pred %p122, %p87, %p119; + or.pred %p123, %p86, %p118; + or.pred %p124, %p85, %p117; + or.pred %p125, %p84, %p116; + or.pred %p126, %p83, %p115; + or.pred %p127, %p82, %p114; + or.pred %p128, %p81, %p113; + .loc 2 155 35 // triton_helpers.py:155:35 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + selp.f32 %r120, %r104, %r118, %p128; + selp.f32 %r121, %r105, %r119, %p127; + selp.f32 %r122, %r106, %r116, %p126; + selp.f32 %r123, %r107, %r117, %p125; + selp.f32 %r124, %r108, %r114, %p124; + selp.f32 %r125, %r109, %r115, %p123; + selp.f32 %r126, %r110, %r112, %p122; + selp.f32 %r127, %r111, %r113, %p121; + cvt.u32.u64 %r128, %rd166; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:41:38 ] + selp.b32 %r129, %r391, %r128, %p128; + selp.b32 %r130, %r392, %r128, %p127; + selp.b32 %r131, %r393, %r128, %p126; + selp.b32 %r132, %r394, %r128, %p125; + selp.b32 %r133, %r395, %r128, %p124; + selp.b32 %r134, %r396, %r128, %p123; + selp.b32 %r135, %r397, %r128, %p122; + selp.b32 %r136, %r398, %r128, %p121; +$L__tmp3: + .loc 1 43 54 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:43:54 + selp.f32 %r137, %r127, %r111, %p17; + selp.f32 %r138, %r126, %r110, %p18; + mov.b64 %rd213, {%r138, %r137}; + selp.f32 %r139, %r125, %r109, %p19; + selp.f32 %r140, %r124, %r108, %p20; + mov.b64 %rd212, {%r140, %r139}; + selp.f32 %r141, %r123, %r107, %p21; + selp.f32 %r142, %r122, %r106, %p22; + mov.b64 %rd211, {%r142, %r141}; + selp.f32 %r143, %r121, %r105, %p23; + selp.f32 %r144, %r120, %r104, %p24; + mov.b64 %rd210, {%r144, %r143}; + .loc 1 44 66 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:44:66 + selp.b32 %r398, %r136, %r398, %p17; + selp.b32 %r397, %r135, %r397, %p18; + selp.b32 %r396, %r134, %r396, %p19; + selp.b32 %r395, %r133, %r395, %p20; + selp.b32 %r394, %r132, %r394, %p21; + selp.b32 %r393, %r131, %r393, %p22; + selp.b32 %r392, %r130, %r392, %p23; + selp.b32 %r391, %r129, %r391, %p24; + .loc 1 32 40 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:32:40 + add.s64 %rd75, %rd209, 64; + add.s64 %rd208, %rd208, 256; + add.s64 %rd207, %rd207, 256; + add.s64 %rd206, %rd206, 256; + add.s64 %rd205, %rd205, 256; + add.s64 %rd204, 
%rd204, 256; + add.s64 %rd203, %rd203, 256; + add.s64 %rd202, %rd202, 256; + add.s64 %rd201, %rd201, 256; + setp.lt.u64 %p129, %rd209, 31936; + mov.b64 %rd209, %rd75; + @%p129 bra $L__BB0_25; +// %bb.26: + .loc 1 23 23 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:23:23 + or.b32 %r185, %r1, %r4; + .loc 1 24 21 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:24:21 + setp.lt.s32 %p151, %r185, %r44; + .loc 1 23 44 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:23:44 + and.b32 %r186, %r2, 31; +$L__tmp4: + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + mov.b64 {%r187, %r188}, %rd213; + shfl.sync.bfly.b32 %r189, %r188, 16, 31, -1; + shfl.sync.bfly.b32 %r190, %r398, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p152, %r188, %r189; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p153, %r188, %r189; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + mov.b64 {%r191, %r192}, %rd210; + setp.nan.f32 %p154, %r191, %r191; + setp.nan.f32 %p155, %r192, %r192; + mov.b64 {%r193, %r194}, %rd211; + setp.nan.f32 %p156, %r193, %r193; + setp.nan.f32 %p157, %r194, %r194; + mov.b64 {%r195, %r196}, %rd212; + setp.nan.f32 %p158, %r195, %r195; + setp.nan.f32 %p159, %r196, %r196; + setp.nan.f32 %p160, %r187, %r187; + setp.nan.f32 %p161, %r188, %r188; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p162, %r189, %r189; + setp.num.f32 %p163, %r189, %r189; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p164, %p161, %p163; + .loc 2 149 16 // triton_helpers.py:149:16 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p165, %p152, %p164; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p166, %p161, %p162; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p167, %p153, %p166; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p168, %r398, %r190; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p169, %p168, %p167; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p170, %p165, %p169; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r197, %r188, %r189, %p170; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r198, %r398, %r190, %p170; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r199, %r197, 8, 31, -1; + shfl.sync.bfly.b32 %r200, %r198, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p171, %r197, %r199; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p172, %r197, %r199; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p173, %r197, %r197; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p174, %r199, %r199; + setp.num.f32 %p175, %r199, %r199; + .loc 2 149 27 // triton_helpers.py:149:27 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p176, %p173, %p175; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p177, %p171, %p176; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p178, %p174, %p173; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p179, %p172, %p178; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p180, %r198, %r200; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p181, %p180, %p179; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p182, %p177, %p181; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r201, %r197, %r199, %p182; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r202, %r198, %r200, %p182; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r203, %r201, 4, 31, -1; + shfl.sync.bfly.b32 %r204, %r202, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p183, %r201, %r203; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p184, %r201, %r203; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p185, %r201, %r201; + .loc 2 148 29 // triton_helpers.py:148:29 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p186, %r203, %r203; + setp.num.f32 %p187, %r203, %r203; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p188, %p185, %p187; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p189, %p183, %p188; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p190, %p186, %p185; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p191, %p184, %p190; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p192, %r202, %r204; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p193, %p192, %p191; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p194, %p189, %p193; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r205, %r201, %r203, %p194; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r206, %r202, %r204, %p194; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r207, %r205, 2, 31, -1; + shfl.sync.bfly.b32 %r208, %r206, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p195, %r205, %r207; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p196, %r205, %r207; + .loc 2 147 29 // triton_helpers.py:147:29 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p197, %r205, %r205; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p198, %r207, %r207; + setp.num.f32 %p199, %r207, %r207; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p200, %p197, %p199; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p201, %p195, %p200; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p202, %p198, %p197; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p203, %p196, %p202; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p204, %r206, %r208; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p205, %p204, %p203; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p206, %p201, %p205; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r209, %r205, %r207, %p206; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r210, %r206, %r208, %p206; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r211, %r209, 1, 31, -1; + shfl.sync.bfly.b32 %r212, %r210, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p207, %r209, %r211; + .loc 2 145 23 // triton_helpers.py:145:23 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p208, %r209, %r211; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p209, %r209, %r209; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p210, %r211, %r211; + setp.num.f32 %p211, %r211, %r211; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p212, %p209, %p211; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p213, %p207, %p212; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p214, %p210, %p209; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p215, %p208, %p214; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p216, %r210, %r212; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p217, %p216, %p215; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p218, %p213, %p217; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r148, %r210, %r212, %p218; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r213, %r187, 16, 31, -1; + shfl.sync.bfly.b32 %r214, %r397, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p219, %r187, %r213; + .loc 2 145 23 // triton_helpers.py:145:23 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p220, %r187, %r213; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p221, %r213, %r213; + setp.num.f32 %p222, %r213, %r213; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p223, %p160, %p222; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p224, %p219, %p223; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p225, %p160, %p221; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p226, %p220, %p225; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p227, %r397, %r214; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p228, %p227, %p226; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p229, %p224, %p228; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r215, %r187, %r213, %p229; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r216, %r397, %r214, %p229; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r217, %r215, 8, 31, -1; + shfl.sync.bfly.b32 %r218, %r216, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p230, %r215, %r217; + .loc 2 145 23 // triton_helpers.py:145:23 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p231, %r215, %r217; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p232, %r215, %r215; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p233, %r217, %r217; + setp.num.f32 %p234, %r217, %r217; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p235, %p232, %p234; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p236, %p230, %p235; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p237, %p233, %p232; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p238, %p231, %p237; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p239, %r216, %r218; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p240, %p239, %p238; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p241, %p236, %p240; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r219, %r215, %r217, %p241; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r220, %r216, %r218, %p241; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r221, %r219, 4, 31, -1; + shfl.sync.bfly.b32 %r222, %r220, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p242, %r219, %r221; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p243, %r219, %r221; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p244, %r219, %r219; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p245, %r221, %r221; + setp.num.f32 %p246, %r221, %r221; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p247, %p244, %p246; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p248, %p242, %p247; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p249, %p245, %p244; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p250, %p243, %p249; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p251, %r220, %r222; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p252, %p251, %p250; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p253, %p248, %p252; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r223, %r219, %r221, %p253; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r224, %r220, %r222, %p253; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + 
shfl.sync.bfly.b32 %r225, %r223, 2, 31, -1; + shfl.sync.bfly.b32 %r226, %r224, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p254, %r223, %r225; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p255, %r223, %r225; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p256, %r223, %r223; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p257, %r225, %r225; + setp.num.f32 %p258, %r225, %r225; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p259, %p256, %p258; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p260, %p254, %p259; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p261, %p257, %p256; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p262, %p255, %p261; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p263, %r224, %r226; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p264, %p263, %p262; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p265, %p260, %p264; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r227, %r223, %r225, %p265; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r228, %r224, 
%r226, %p265; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r229, %r227, 1, 31, -1; + shfl.sync.bfly.b32 %r230, %r228, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p266, %r227, %r229; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p267, %r227, %r229; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p268, %r227, %r227; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p269, %r229, %r229; + setp.num.f32 %p270, %r229, %r229; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p271, %p268, %p270; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p272, %p266, %p271; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p273, %p269, %p268; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p274, %p267, %p273; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p275, %r228, %r230; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p276, %p275, %p274; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p277, %p272, %p276; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r152, %r228, %r230, %p277; + .loc 2 
165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r231, %r196, 16, 31, -1; + shfl.sync.bfly.b32 %r232, %r396, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p278, %r196, %r231; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p279, %r196, %r231; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p280, %r231, %r231; + setp.num.f32 %p281, %r231, %r231; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p282, %p159, %p281; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p283, %p278, %p282; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p284, %p159, %p280; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p285, %p279, %p284; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p286, %r396, %r232; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p287, %p286, %p285; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p288, %p283, %p287; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r233, %r196, %r231, %p288; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r234, %r396, %r232, %p288; + .loc 2 165 42 // 
triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r235, %r233, 8, 31, -1; + shfl.sync.bfly.b32 %r236, %r234, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p289, %r233, %r235; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p290, %r233, %r235; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p291, %r233, %r233; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p292, %r235, %r235; + setp.num.f32 %p293, %r235, %r235; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p294, %p291, %p293; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p295, %p289, %p294; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p296, %p292, %p291; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p297, %p290, %p296; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p298, %r234, %r236; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p299, %p298, %p297; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p300, %p295, %p299; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r237, %r233, %r235, %p300; + .loc 2 155 69 // 
triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r238, %r234, %r236, %p300; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r239, %r237, 4, 31, -1; + shfl.sync.bfly.b32 %r240, %r238, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p301, %r237, %r239; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p302, %r237, %r239; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p303, %r237, %r237; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p304, %r239, %r239; + setp.num.f32 %p305, %r239, %r239; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p306, %p303, %p305; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p307, %p301, %p306; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p308, %p304, %p303; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p309, %p302, %p308; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p310, %r238, %r240; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p311, %p310, %p309; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p312, %p307, %p311; + .loc 2 155 35 // 
triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r241, %r237, %r239, %p312; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r242, %r238, %r240, %p312; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r243, %r241, 2, 31, -1; + shfl.sync.bfly.b32 %r244, %r242, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p313, %r241, %r243; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p314, %r241, %r243; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p315, %r241, %r241; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p316, %r243, %r243; + setp.num.f32 %p317, %r243, %r243; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p318, %p315, %p317; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p319, %p313, %p318; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p320, %p316, %p315; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p321, %p314, %p320; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p322, %r242, %r244; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p323, %p322, %p321; + .loc 2 154 12 // 
triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p324, %p319, %p323; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r245, %r241, %r243, %p324; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r246, %r242, %r244, %p324; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r247, %r245, 1, 31, -1; + shfl.sync.bfly.b32 %r248, %r246, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p325, %r245, %r247; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p326, %r245, %r247; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p327, %r245, %r245; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p328, %r247, %r247; + setp.num.f32 %p329, %r247, %r247; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p330, %p327, %p329; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p331, %p325, %p330; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p332, %p328, %p327; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p333, %p326, %p332; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p334, %r246, %r248; + .loc 2 154 21 // 
triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p335, %p334, %p333; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p336, %p331, %p335; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r156, %r246, %r248, %p336; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r249, %r195, 16, 31, -1; + shfl.sync.bfly.b32 %r250, %r395, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p337, %r195, %r249; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p338, %r195, %r249; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p339, %r249, %r249; + setp.num.f32 %p340, %r249, %r249; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p341, %p158, %p340; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p342, %p337, %p341; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p343, %p158, %p339; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p344, %p338, %p343; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p345, %r395, %r250; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p346, %p345, %p344; + .loc 2 154 12 // triton_helpers.py:154:12 
@[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p347, %p342, %p346; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r251, %r195, %r249, %p347; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r252, %r395, %r250, %p347; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r253, %r251, 8, 31, -1; + shfl.sync.bfly.b32 %r254, %r252, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p348, %r251, %r253; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p349, %r251, %r253; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p350, %r251, %r251; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p351, %r253, %r253; + setp.num.f32 %p352, %r253, %r253; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p353, %p350, %p352; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p354, %p348, %p353; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p355, %p351, %p350; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p356, %p349, %p355; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p357, %r252, %r254; + .loc 2 154 21 // triton_helpers.py:154:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p358, %p357, %p356; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p359, %p354, %p358; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r255, %r251, %r253, %p359; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r256, %r252, %r254, %p359; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r257, %r255, 4, 31, -1; + shfl.sync.bfly.b32 %r258, %r256, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p360, %r255, %r257; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p361, %r255, %r257; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p362, %r255, %r255; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p363, %r257, %r257; + setp.num.f32 %p364, %r257, %r257; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p365, %p362, %p364; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p366, %p360, %p365; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p367, %p363, %p362; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p368, %p361, %p367; + .loc 2 154 31 // triton_helpers.py:154:31 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p369, %r256, %r258; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p370, %p369, %p368; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p371, %p366, %p370; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r259, %r255, %r257, %p371; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r260, %r256, %r258, %p371; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r261, %r259, 2, 31, -1; + shfl.sync.bfly.b32 %r262, %r260, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p372, %r259, %r261; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p373, %r259, %r261; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p374, %r259, %r259; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p375, %r261, %r261; + setp.num.f32 %p376, %r261, %r261; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p377, %p374, %p376; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p378, %p372, %p377; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p379, %p375, %p374; + .loc 2 151 17 // triton_helpers.py:151:17 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p380, %p373, %p379; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p381, %r260, %r262; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p382, %p381, %p380; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p383, %p378, %p382; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r263, %r259, %r261, %p383; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r264, %r260, %r262, %p383; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r265, %r263, 1, 31, -1; + shfl.sync.bfly.b32 %r266, %r264, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p384, %r263, %r265; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p385, %r263, %r265; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p386, %r263, %r263; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p387, %r265, %r265; + setp.num.f32 %p388, %r265, %r265; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p389, %p386, %p388; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p390, %p384, %p389; + .loc 2 151 27 // triton_helpers.py:151:27 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p391, %p387, %p386; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p392, %p385, %p391; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p393, %r264, %r266; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p394, %p393, %p392; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p395, %p390, %p394; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r160, %r264, %r266, %p395; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r267, %r194, 16, 31, -1; + shfl.sync.bfly.b32 %r268, %r394, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p396, %r194, %r267; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p397, %r194, %r267; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p398, %r267, %r267; + setp.num.f32 %p399, %r267, %r267; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p400, %p157, %p399; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p401, %p396, %p400; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p402, %p157, %p398; + .loc 2 151 17 // triton_helpers.py:151:17 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p403, %p397, %p402; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p404, %r394, %r268; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p405, %p404, %p403; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p406, %p401, %p405; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r269, %r194, %r267, %p406; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r270, %r394, %r268, %p406; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r271, %r269, 8, 31, -1; + shfl.sync.bfly.b32 %r272, %r270, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p407, %r269, %r271; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p408, %r269, %r271; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p409, %r269, %r269; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p410, %r271, %r271; + setp.num.f32 %p411, %r271, %r271; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p412, %p409, %p411; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p413, %p407, %p412; + .loc 2 151 27 // triton_helpers.py:151:27 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p414, %p410, %p409; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p415, %p408, %p414; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p416, %r270, %r272; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p417, %p416, %p415; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p418, %p413, %p417; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r273, %r269, %r271, %p418; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r274, %r270, %r272, %p418; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r275, %r273, 4, 31, -1; + shfl.sync.bfly.b32 %r276, %r274, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p419, %r273, %r275; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p420, %r273, %r275; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p421, %r273, %r273; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p422, %r275, %r275; + setp.num.f32 %p423, %r275, %r275; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p424, %p421, %p423; + .loc 2 149 16 // triton_helpers.py:149:16 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p425, %p419, %p424; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p426, %p422, %p421; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p427, %p420, %p426; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p428, %r274, %r276; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p429, %p428, %p427; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p430, %p425, %p429; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r277, %r273, %r275, %p430; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r278, %r274, %r276, %p430; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r279, %r277, 2, 31, -1; + shfl.sync.bfly.b32 %r280, %r278, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p431, %r277, %r279; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p432, %r277, %r279; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p433, %r277, %r277; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p434, %r279, %r279; + setp.num.f32 %p435, %r279, %r279; + .loc 2 149 27 // triton_helpers.py:149:27 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p436, %p433, %p435; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p437, %p431, %p436; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p438, %p434, %p433; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p439, %p432, %p438; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p440, %r278, %r280; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p441, %p440, %p439; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p442, %p437, %p441; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r281, %r277, %r279, %p442; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r282, %r278, %r280, %p442; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r283, %r281, 1, 31, -1; + shfl.sync.bfly.b32 %r284, %r282, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p443, %r281, %r283; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p444, %r281, %r283; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p445, %r281, %r281; + .loc 2 148 29 // triton_helpers.py:148:29 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p446, %r283, %r283; + setp.num.f32 %p447, %r283, %r283; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p448, %p445, %p447; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p449, %p443, %p448; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p450, %p446, %p445; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p451, %p444, %p450; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p452, %r282, %r284; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p453, %p452, %p451; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p454, %p449, %p453; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r164, %r282, %r284, %p454; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r285, %r193, 16, 31, -1; + shfl.sync.bfly.b32 %r286, %r393, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p455, %r193, %r285; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p456, %r193, %r285; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p457, %r285, %r285; + setp.num.f32 %p458, %r285, %r285; + .loc 2 149 27 // 
triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p459, %p156, %p458; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p460, %p455, %p459; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p461, %p156, %p457; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p462, %p456, %p461; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p463, %r393, %r286; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p464, %p463, %p462; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p465, %p460, %p464; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r287, %r193, %r285, %p465; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r288, %r393, %r286, %p465; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r289, %r287, 8, 31, -1; + shfl.sync.bfly.b32 %r290, %r288, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p466, %r287, %r289; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p467, %r287, %r289; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p468, %r287, %r287; + .loc 2 148 29 // triton_helpers.py:148:29 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p469, %r289, %r289; + setp.num.f32 %p470, %r289, %r289; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p471, %p468, %p470; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p472, %p466, %p471; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p473, %p469, %p468; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p474, %p467, %p473; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p475, %r288, %r290; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p476, %p475, %p474; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p477, %p472, %p476; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r291, %r287, %r289, %p477; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r292, %r288, %r290, %p477; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r293, %r291, 4, 31, -1; + shfl.sync.bfly.b32 %r294, %r292, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p478, %r291, %r293; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p479, %r291, %r293; + .loc 2 147 29 // triton_helpers.py:147:29 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p480, %r291, %r291; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p481, %r293, %r293; + setp.num.f32 %p482, %r293, %r293; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p483, %p480, %p482; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p484, %p478, %p483; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p485, %p481, %p480; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p486, %p479, %p485; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p487, %r292, %r294; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p488, %p487, %p486; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p489, %p484, %p488; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r295, %r291, %r293, %p489; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r296, %r292, %r294, %p489; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r297, %r295, 2, 31, -1; + shfl.sync.bfly.b32 %r298, %r296, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p490, %r295, %r297; + .loc 2 145 23 // triton_helpers.py:145:23 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p491, %r295, %r297; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p492, %r295, %r295; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p493, %r297, %r297; + setp.num.f32 %p494, %r297, %r297; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p495, %p492, %p494; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p496, %p490, %p495; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p497, %p493, %p492; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p498, %p491, %p497; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p499, %r296, %r298; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p500, %p499, %p498; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p501, %p496, %p500; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r299, %r295, %r297, %p501; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r300, %r296, %r298, %p501; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r301, %r299, 1, 31, -1; + shfl.sync.bfly.b32 %r302, %r300, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p502, %r299, %r301; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p503, %r299, %r301; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p504, %r299, %r299; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p505, %r301, %r301; + setp.num.f32 %p506, %r301, %r301; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p507, %p504, %p506; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p508, %p502, %p507; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p509, %p505, %p504; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p510, %p503, %p509; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p511, %r300, %r302; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p512, %p511, %p510; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p513, %p508, %p512; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r168, %r300, %r302, %p513; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r303, %r192, 16, 31, -1; + shfl.sync.bfly.b32 %r304, %r392, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p514, %r192, %r303; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p515, %r192, %r303; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p516, %r303, %r303; + setp.num.f32 %p517, %r303, %r303; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p518, %p155, %p517; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p519, %p514, %p518; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p520, %p155, %p516; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p521, %p515, %p520; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p522, %r392, %r304; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p523, %p522, %p521; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p524, %p519, %p523; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r305, %r192, %r303, %p524; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r306, %r392, %r304, %p524; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r307, %r305, 8, 31, -1; + shfl.sync.bfly.b32 %r308, %r306, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ 
cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p525, %r305, %r307; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p526, %r305, %r307; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p527, %r305, %r305; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p528, %r307, %r307; + setp.num.f32 %p529, %r307, %r307; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p530, %p527, %p529; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p531, %p525, %p530; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p532, %p528, %p527; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p533, %p526, %p532; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p534, %r306, %r308; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p535, %p534, %p533; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p536, %p531, %p535; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r309, %r305, %r307, %p536; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r310, %r306, %r308, %p536; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + 
shfl.sync.bfly.b32 %r311, %r309, 4, 31, -1; + shfl.sync.bfly.b32 %r312, %r310, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p537, %r309, %r311; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p538, %r309, %r311; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p539, %r309, %r309; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p540, %r311, %r311; + setp.num.f32 %p541, %r311, %r311; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p542, %p539, %p541; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p543, %p537, %p542; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p544, %p540, %p539; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p545, %p538, %p544; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p546, %r310, %r312; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p547, %p546, %p545; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p548, %p543, %p547; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r313, %r309, %r311, %p548; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r314, %r310, 
%r312, %p548; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r315, %r313, 2, 31, -1; + shfl.sync.bfly.b32 %r316, %r314, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p549, %r313, %r315; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p550, %r313, %r315; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p551, %r313, %r313; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p552, %r315, %r315; + setp.num.f32 %p553, %r315, %r315; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p554, %p551, %p553; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p555, %p549, %p554; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p556, %p552, %p551; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p557, %p550, %p556; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p558, %r314, %r316; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p559, %p558, %p557; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p560, %p555, %p559; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r317, %r313, %r315, %p560; + .loc 2 
155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r318, %r314, %r316, %p560; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r319, %r317, 1, 31, -1; + shfl.sync.bfly.b32 %r320, %r318, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p561, %r317, %r319; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p562, %r317, %r319; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p563, %r317, %r317; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p564, %r319, %r319; + setp.num.f32 %p565, %r319, %r319; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p566, %p563, %p565; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p567, %p561, %p566; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p568, %p564, %p563; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p569, %p562, %p568; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p570, %r318, %r320; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p571, %p570, %p569; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p572, %p567, %p571; + .loc 2 155 69 // 
triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r172, %r318, %r320, %p572; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r321, %r191, 16, 31, -1; + shfl.sync.bfly.b32 %r322, %r391, 16, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p573, %r191, %r321; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p574, %r191, %r321; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p575, %r321, %r321; + setp.num.f32 %p576, %r321, %r321; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p577, %p154, %p576; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p578, %p573, %p577; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p579, %p154, %p575; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p580, %p574, %p579; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p581, %r391, %r322; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p582, %p581, %p580; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p583, %p578, %p582; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r323, %r191, %r321, %p583; + .loc 2 155 69 // 
triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r324, %r391, %r322, %p583; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r325, %r323, 8, 31, -1; + shfl.sync.bfly.b32 %r326, %r324, 8, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p584, %r323, %r325; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p585, %r323, %r325; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p586, %r323, %r323; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p587, %r325, %r325; + setp.num.f32 %p588, %r325, %r325; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p589, %p586, %p588; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p590, %p584, %p589; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p591, %p587, %p586; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p592, %p585, %p591; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p593, %r324, %r326; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p594, %p593, %p592; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p595, %p590, %p594; + .loc 2 155 35 // 
triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r327, %r323, %r325, %p595; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r328, %r324, %r326, %p595; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r329, %r327, 4, 31, -1; + shfl.sync.bfly.b32 %r330, %r328, 4, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p596, %r327, %r329; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p597, %r327, %r329; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p598, %r327, %r327; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p599, %r329, %r329; + setp.num.f32 %p600, %r329, %r329; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p601, %p598, %p600; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p602, %p596, %p601; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p603, %p599, %p598; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p604, %p597, %p603; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p605, %r328, %r330; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p606, %p605, %p604; + .loc 2 154 12 // 
triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p607, %p602, %p606; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r331, %r327, %r329, %p607; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r332, %r328, %r330, %p607; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r333, %r331, 2, 31, -1; + shfl.sync.bfly.b32 %r334, %r332, 2, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p608, %r331, %r333; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p609, %r331, %r333; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p610, %r331, %r331; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p611, %r333, %r333; + setp.num.f32 %p612, %r333, %r333; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p613, %p610, %p612; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p614, %p608, %p613; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p615, %p611, %p610; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p616, %p609, %p615; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p617, %r332, %r334; + .loc 2 154 21 // 
triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p618, %p617, %p616; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p619, %p614, %p618; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.f32 %r335, %r331, %r333, %p619; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r336, %r332, %r334, %p619; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + shfl.sync.bfly.b32 %r337, %r335, 1, 31, -1; + shfl.sync.bfly.b32 %r338, %r336, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p620, %r335, %r337; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p621, %r335, %r337; + .loc 2 147 29 // triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p622, %r335, %r335; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p623, %r337, %r337; + setp.num.f32 %p624, %r337, %r337; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p625, %p622, %p624; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p626, %p620, %p625; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p627, %p623, %p622; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p628, %p621, %p627; + .loc 2 154 31 // 
triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p629, %r336, %r338; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p630, %p629, %p628; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p631, %p626, %p630; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r176, %r336, %r338, %p631; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + bfe.u32 %r339, %r2, 5, 1; + setp.eq.b32 %p130, %r186, 0; + shr.u32 %r340, %r3, 5; + or.b32 %r341, %r340, %r339; + shl.b32 %r342, %r341, 2; + mov.b32 %r343, global_smem; + add.s32 %r145, %r343, %r342; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r146, %r209, %r211, %p218; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r145 + 0 ], %r146; + // end inline asm + add.s32 %r344, %r343, 512; + add.s32 %r147, %r344, %r342; + // begin inline asm + @%p130 st.shared.b32 [ %r147 + 0 ], %r148; + // end inline asm + shl.b32 %r345, %r339, 2; + shl.b32 %r346, %r5, 3; + or.b32 %r347, %r346, %r345; + add.s32 %r149, %r343, %r347; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r150, %r227, %r229, %p277; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r149 + 0 ], %r150; + // end inline asm + add.s32 %r151, %r344, %r347; + // begin inline asm + @%p130 st.shared.b32 [ %r151 + 0 ], %r152; + // end inline asm + shl.b32 %r348, %r6, 3; + or.b32 %r349, %r348, 
%r345; + add.s32 %r153, %r343, %r349; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r154, %r245, %r247, %p336; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r153 + 0 ], %r154; + // end inline asm + add.s32 %r155, %r344, %r349; + // begin inline asm + @%p130 st.shared.b32 [ %r155 + 0 ], %r156; + // end inline asm + shl.b32 %r350, %r7, 3; + or.b32 %r351, %r350, %r345; + add.s32 %r157, %r343, %r351; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r158, %r263, %r265, %p395; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r157 + 0 ], %r158; + // end inline asm + add.s32 %r159, %r344, %r351; + // begin inline asm + @%p130 st.shared.b32 [ %r159 + 0 ], %r160; + // end inline asm + shl.b32 %r352, %r11, 3; + or.b32 %r353, %r352, %r345; + add.s32 %r161, %r343, %r353; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r162, %r281, %r283, %p454; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r161 + 0 ], %r162; + // end inline asm + add.s32 %r163, %r344, %r353; + // begin inline asm + @%p130 st.shared.b32 [ %r163 + 0 ], %r164; + // end inline asm + shl.b32 %r354, %r10, 3; + or.b32 %r355, %r354, %r345; + add.s32 %r165, %r343, %r355; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r166, %r299, %r301, %p513; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 
st.shared.b32 [ %r165 + 0 ], %r166; + // end inline asm + add.s32 %r167, %r344, %r355; + // begin inline asm + @%p130 st.shared.b32 [ %r167 + 0 ], %r168; + // end inline asm + shl.b32 %r356, %r9, 3; + or.b32 %r357, %r356, %r345; + add.s32 %r169, %r343, %r357; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r170, %r317, %r319, %p572; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r169 + 0 ], %r170; + // end inline asm + add.s32 %r171, %r344, %r357; + // begin inline asm + @%p130 st.shared.b32 [ %r171 + 0 ], %r172; + // end inline asm + shl.b32 %r358, %r8, 3; + or.b32 %r359, %r358, %r345; + add.s32 %r173, %r343, %r359; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r174, %r335, %r337, %p631; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p130 st.shared.b32 [ %r173 + 0 ], %r174; + // end inline asm + add.s32 %r175, %r344, %r359; + // begin inline asm + @%p130 st.shared.b32 [ %r175 + 0 ], %r176; + // end inline asm + bar.sync 0; + setp.lt.u32 %p146, %r2, 128; + shl.b32 %r360, %r2, 2; + add.s32 %r178, %r343, %r360; + // begin inline asm + @%p146 ld.shared.b32 %r177, [ %r178 + 0 ]; + // end inline asm + add.s32 %r180, %r344, %r360; + // begin inline asm + @%p146 ld.shared.b32 %r179, [ %r180 + 0 ]; + // end inline asm + shfl.sync.bfly.b32 %r361, %r177, 1, 31, -1; + shfl.sync.bfly.b32 %r362, %r179, 1, 31, -1; + .loc 2 144 21 // triton_helpers.py:144:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.gt.f32 %p632, %r177, %r361; + .loc 2 145 23 // triton_helpers.py:145:23 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.eq.f32 %p633, %r177, %r361; + .loc 2 147 29 // 
triton_helpers.py:147:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p634, %r177, %r177; + .loc 2 148 29 // triton_helpers.py:148:29 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.nan.f32 %p635, %r361, %r361; + setp.num.f32 %p636, %r361, %r361; + .loc 2 149 27 // triton_helpers.py:149:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p637, %p634, %p636; + .loc 2 149 16 // triton_helpers.py:149:16 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p638, %p632, %p637; + .loc 2 151 27 // triton_helpers.py:151:27 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p639, %p634, %p635; + .loc 2 151 17 // triton_helpers.py:151:17 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p640, %p633, %p639; + .loc 2 154 31 // triton_helpers.py:154:31 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + setp.lt.s32 %p641, %r179, %r362; + .loc 2 154 21 // triton_helpers.py:154:21 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.pred %p642, %p641, %p640; + .loc 2 154 12 // triton_helpers.py:154:12 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + or.pred %p643, %p638, %p642; + .loc 2 155 69 // triton_helpers.py:155:69 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r184, %r179, %r362, %p643; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + and.b32 %r363, %r2, 897; + setp.eq.b32 %p148, %r363, 0; + .loc 2 155 35 // triton_helpers.py:155:35 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + selp.b32 %r182, %r177, %r361, %p643; + .loc 2 165 42 // triton_helpers.py:165:42 @[ cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:45:75 ] + // begin inline asm + @%p148 st.shared.b32 [ %r178 + 0 ], %r182; + // end inline asm + // begin inline 
asm + @%p148 st.shared.b32 [ %r180 + 0 ], %r184; + // end inline asm + bar.sync 0; + shr.u32 %r364, %r3, 3; + add.s32 %r365, %r344, %r364; + ld.shared.b32 %r366, [%r365]; + add.s32 %r367, %r344, %r346; + ld.shared.b32 %r368, [%r367]; + add.s32 %r369, %r344, %r348; + ld.shared.b32 %r370, [%r369]; + add.s32 %r371, %r344, %r350; + ld.shared.b32 %r372, [%r371]; + add.s32 %r373, %r344, %r352; + ld.shared.b32 %r374, [%r373]; + add.s32 %r375, %r344, %r354; + ld.shared.b32 %r376, [%r375]; + add.s32 %r377, %r344, %r356; + ld.shared.b32 %r378, [%r377]; + add.s32 %r379, %r344, %r358; + ld.shared.b32 %r380, [%r379]; +$L__tmp5: + .loc 1 47 25 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:47:25 + mad.wide.s32 %rd192, %r185, 8, %rd85; + .loc 1 47 36 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:47:36 + bar.sync 0; + shr.u32 %r381, %r3, 2; + add.s32 %r382, %r343, %r381; + st.shared.v4.b32 [%r382], {%r366, %r368, %r370, %r372}; + st.shared.v4.b32 [%r382+128], {%r374, %r376, %r378, %r380}; + bar.sync 0; + shl.b32 %r383, %r2, 4; + and.b32 %r384, %r383, 112; + shr.u32 %r385, %r2, 1; + and.b32 %r386, %r385, 12; + and.b32 %r387, %r360, 128; + add.s32 %r388, %r343, %r384; + add.s32 %r389, %r388, %r387; + add.s32 %r390, %r389, %r386; + ld.shared.s32 %rd191, [%r390]; + setp.eq.b32 %p644, %r3, 0; + and.pred %p150, %p644, %p151; + // begin inline asm + @%p150 st.global.b64 [ %rd192 + 0 ], { %rd191 }; + // end inline asm + .loc 1 47 4 // cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py:47:4 + ret; +$L__tmp6: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py" + .file 2 "/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language 
+.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 234 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. 
Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xe3 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 118 +.b8 97 +.b8 51 +.b8 51 +.b8 119 +.b8 53 +.b8 55 +.b8 99 +.b8 106 +.b8 118 +.b8 109 +.b8 109 +.b8 121 +.b8 101 +.b8 113 +.b8 104 +.b8 122 +.b8 51 +.b8 104 +.b8 122 +.b8 97 +.b8 103 +.b8 104 +.b8 111 +.b8 120 +.b8 104 +.b8 55 +.b8 116 +.b8 121 +.b8 115 +.b8 117 +.b8 52 +.b8 107 +.b8 106 +.b8 120 +.b8 103 +.b8 50 +.b8 52 +.b8 104 +.b8 98 +.b8 100 +.b8 103 +.b8 116 +.b8 119 +.b8 101 +.b8 99 +.b8 53 +.b8 99 +.b8 110 +.b8 51 +.b8 115 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 45 +.b8 101 +.b8 120 +.b8 116 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 118 +.b8 97 +.b8 0 +.b8 2 // Abbrev [2] 0x8b:0x1c DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 97 +.b8 114 +.b8 103 +.b8 109 +.b8 97 +.b8 120 +.b8 95 +.b8 49 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa7:0x46 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 139 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbc:0x18 DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp0 // DW_AT_low_pc +.b64 $L__tmp3 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 41 // DW_AT_call_line +.b8 38 // DW_AT_call_column +.b8 4 // Abbrev [4] 0xd4:0x18 
DW_TAG_inlined_subroutine +.b32 139 // DW_AT_abstract_origin +.b64 $L__tmp4 // DW_AT_low_pc +.b64 $L__tmp5 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 45 // DW_AT_call_line +.b8 75 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttgir b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..0ec6aee138514f424db4f9f8f05e4bc31369e647 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttgir @@ -0,0 +1,217 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 8], order = [0, 1]}> +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":18:0) +#loc1 = loc(unknown) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":45:75) +#loc44 = loc("in_ptr0"(#loc)) +#loc45 = loc("out_ptr0"(#loc)) +#loc46 = loc("ks0"(#loc)) +#loc47 = loc("ks1"(#loc)) +#loc48 = loc("xnumel"(#loc)) +#loc49 = loc("r0_numel"(#loc)) +#loc85 = loc(callsite(#loc1 at #loc39)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_argmax_1(%in_ptr0: !tt.ptr loc("in_ptr0"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 loc("ks1"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 
{tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<32000> : tensor<64x1xi64, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c32000_i32 = arith.constant 32000 : i32 loc(#loc1) + %cst_1 = arith.constant dense : tensor<64x64xi1, #blocked> loc(#loc1) + %true = arith.constant true loc(#loc1) + %cst_2 = arith.constant dense<32000> : tensor<1x64xi32, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<2147483647> : tensor<64x64xi32, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<0xFF800000> : tensor<64x64xf32, #blocked> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc50) + %xoffset_5 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc51) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc52) + %xindex_6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc52) + %xindex_7 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc52) + %xindex_8 = tt.expand_dims %xindex_6 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> loc(#loc52) + %xindex_9 = tt.splat %xoffset_5 : i32 -> tensor<64x1xi32, #blocked> loc(#loc53) + %xindex_10 = tt.splat %xoffset_5 : i32 -> tensor<64x1xi32, #blocked1> loc(#loc53) + %xindex_11 = arith.addi %xindex_9, %xindex_7 : tensor<64x1xi32, #blocked> loc(#loc53) + %xindex_12 = arith.addi %xindex_10, %xindex_8 : tensor<64x1xi32, #blocked1> loc(#loc53) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked> loc(#loc54) + %xmask_13 = tt.splat %xnumel : i32 -> tensor<64x1xi32, #blocked1> loc(#loc54) + %xmask_14 = arith.cmpi slt, %xindex_11, %xmask 
: tensor<64x1xi32, #blocked> loc(#loc54) + %xmask_15 = arith.cmpi slt, %xindex_12, %xmask_13 : tensor<64x1xi32, #blocked1> loc(#loc54) + %r0_base = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc55) + %r0_base_16 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> loc(#loc55) + %x0 = arith.extsi %xindex_11 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> loc(#loc56) + %x0_17 = tt.splat %ks0 : i64 -> tensor<64x1xi64, #blocked> loc(#loc56) + %x0_18 = arith.remsi %x0, %x0_17 : tensor<64x1xi64, #blocked> loc(#loc56) + %x1 = arith.divsi %x0, %x0_17 : tensor<64x1xi64, #blocked> loc(#loc57) + %tmp0 = arith.muli %x0_18, %cst : tensor<64x1xi64, #blocked> loc(#loc58) + %tmp0_19 = tt.broadcast %tmp0 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> loc(#loc59) + %tmp0_20 = tt.splat %ks1 : i64 -> tensor<64x1xi64, #blocked> loc(#loc60) + %tmp0_21 = arith.muli %tmp0_20, %x1 : tensor<64x1xi64, #blocked> loc(#loc60) + %tmp0_22 = tt.broadcast %tmp0_21 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> loc(#loc61) + %tmp0_23 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> loc(#loc62) + %tmp0_24 = tt.broadcast %xmask_14 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc63) + %_tmp2_index:2 = scf.for %r0_offset = %c0_i32 to %c32000_i32 step %c64_i32 iter_args(%_tmp2 = %cst_4, %_tmp2_index_25 = %cst_3) -> (tensor<64x64xf32, #blocked>, tensor<64x64xi32, #blocked>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x64xi32, #blocked> loc(#loc65) + %r0_index_26 = arith.addi %r0_index, %r0_base_16 : tensor<1x64xi32, #blocked> loc(#loc65) + %r0_mask = arith.cmpi slt, %r0_index_26, %cst_2 : tensor<1x64xi32, #blocked> loc(#loc66) + %tmp0_27 = arith.extsi %r0_index_26 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> loc(#loc59) + %tmp0_28 = tt.broadcast %tmp0_27 : 
tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> loc(#loc59) + %tmp0_29 = arith.addi %tmp0_28, %tmp0_19 : tensor<64x64xi64, #blocked> loc(#loc59) + %tmp0_30 = arith.addi %tmp0_29, %tmp0_22 : tensor<64x64xi64, #blocked> loc(#loc61) + %tmp0_31 = tt.addptr %tmp0_23, %tmp0_30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> loc(#loc62) + %tmp0_32 = tt.broadcast %r0_mask : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> loc(#loc63) + %tmp0_33 = arith.andi %tmp0_32, %tmp0_24 : tensor<64x64xi1, #blocked> loc(#loc63) + %tmp0_34 = tt.load %tmp0_31, %tmp0_33, %cst_0 evictionPolicy = evict_first : tensor<64x64x!tt.ptr, #blocked> loc(#loc67) + %mask = arith.cmpf ogt, %_tmp2, %tmp0_34 : tensor<64x64xf32, #blocked> loc(#loc110) + %equal = arith.cmpf oeq, %_tmp2, %tmp0_34 : tensor<64x64xf32, #blocked> loc(#loc111) + %a_isnan = arith.cmpf une, %_tmp2, %_tmp2 : tensor<64x64xf32, #blocked> loc(#loc90) + %b_isnan = arith.cmpf une, %tmp0_34, %tmp0_34 : tensor<64x64xf32, #blocked> loc(#loc91) + %mask_35 = arith.xori %b_isnan, %cst_1 : tensor<64x64xi1, #blocked> loc(#loc92) + %mask_36 = arith.andi %a_isnan, %mask_35 : tensor<64x64xi1, #blocked> loc(#loc93) + %mask_37 = arith.ori %mask, %mask_36 : tensor<64x64xi1, #blocked> loc(#loc112) + %equal_38 = arith.andi %a_isnan, %b_isnan : tensor<64x64xi1, #blocked> loc(#loc95) + %equal_39 = arith.ori %equal, %equal_38 : tensor<64x64xi1, #blocked> loc(#loc113) + %mask_40 = tt.broadcast %r0_index_26 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> loc(#loc97) + %mask_41 = arith.cmpi slt, %_tmp2_index_25, %mask_40 : tensor<64x64xi32, #blocked> loc(#loc97) + %mask_42 = arith.andi %equal_39, %mask_41 : tensor<64x64xi1, #blocked> loc(#loc98) + %mask_43 = arith.ori %mask_37, %mask_42 : tensor<64x64xi1, #blocked> loc(#loc99) + %5 = arith.select %mask_43, %_tmp2, %tmp0_34 : tensor<64x64xi1, #blocked>, tensor<64x64xf32, #blocked> loc(#loc80) + %6 = arith.select %mask_43, %_tmp2_index_25, %mask_40 : 
tensor<64x64xi1, #blocked>, tensor<64x64xi32, #blocked> loc(#loc81) + %_tmp2_44 = arith.select %tmp0_33, %5, %_tmp2 : tensor<64x64xi1, #blocked>, tensor<64x64xf32, #blocked> loc(#loc82) + %_tmp2_index_45 = arith.select %tmp0_33, %6, %_tmp2_index_25 : tensor<64x64xi1, #blocked>, tensor<64x64xi32, #blocked> loc(#loc83) + scf.yield %_tmp2_44, %_tmp2_index_45 : tensor<64x64xf32, #blocked>, tensor<64x64xi32, #blocked> loc(#loc37) + } loc(#loc87) + %0:2 = "tt.reduce"(%_tmp2_index#0, %_tmp2_index#1) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc39)), %arg7: i32 loc(callsite(#loc1 at #loc39)), %arg8: f32 loc(callsite(#loc1 at #loc39)), %arg9: i32 loc(callsite(#loc1 at #loc39))): + %mask = arith.cmpf ogt, %arg6, %arg8 : f32 loc(#loc114) + %equal = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc115) + %a_isnan = arith.cmpf une, %arg6, %arg6 : f32 loc(#loc100) + %b_isnan = arith.cmpf une, %arg8, %arg8 : f32 loc(#loc101) + %mask_25 = arith.xori %b_isnan, %true : i1 loc(#loc102) + %mask_26 = arith.andi %a_isnan, %mask_25 : i1 loc(#loc103) + %mask_27 = arith.ori %mask, %mask_26 : i1 loc(#loc116) + %equal_28 = arith.andi %a_isnan, %b_isnan : i1 loc(#loc104) + %equal_29 = arith.ori %equal, %equal_28 : i1 loc(#loc117) + %mask_30 = arith.cmpi slt, %arg7, %arg9 : i32 loc(#loc105) + %mask_31 = arith.andi %equal_29, %mask_30 : i1 loc(#loc106) + %mask_32 = arith.ori %mask_27, %mask_31 : i1 loc(#loc107) + %5 = arith.select %mask_32, %arg6, %arg8 : f32 loc(#loc108) + %6 = arith.select %mask_32, %arg7, %arg9 : i32 loc(#loc109) + tt.reduce.return %5, %6 : f32, i32 loc(#loc84) + }) : (tensor<64x64xf32, #blocked>, tensor<64x64xi32, #blocked>) -> (tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) loc(#loc84) + %tmp2 = tt.expand_dims %0#1 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc86) + %1 = tt.splat %out_ptr0 : !tt.ptr -> tensor<64x1x!tt.ptr, 
#blocked1> loc(#loc41) + %2 = tt.addptr %1, %xindex_12 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> loc(#loc41) + %3 = ttg.convert_layout %tmp2 : tensor<64x1xi32, #blocked> -> tensor<64x1xi32, #blocked1> loc(#loc42) + %4 = arith.extsi %3 : tensor<64x1xi32, #blocked1> to tensor<64x1xi64, #blocked1> loc(#loc42) + tt.store %2, %4, %xmask_15 : tensor<64x1x!tt.ptr, #blocked1> loc(#loc42) + tt.return loc(#loc43) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":22:28) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":22:33) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":23:44) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":23:23) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":24:21) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":25:37) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":27:19) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":28:19) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:47) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:41) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:56) +#loc13 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:52) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:34) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:71) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":32:40) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":33:31) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":34:29) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:61) +#loc20 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":144:21) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":41:38) +#loc22 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":145:23) +#loc23 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":147:29) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":148:29) +#loc25 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:31) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:27) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:16) +#loc28 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:27) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:17) +#loc30 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:31) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:21) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:12) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:35) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:69) +#loc35 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":43:54) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":44:66) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":44:8) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:42) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":46:20) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:25) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:36) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:4) +#loc50 = loc("xoffset"(#loc2)) +#loc51 = loc("xoffset"(#loc3)) +#loc52 = loc("xindex"(#loc4)) +#loc53 = 
loc("xindex"(#loc5)) +#loc54 = loc("xmask"(#loc6)) +#loc55 = loc("r0_base"(#loc7)) +#loc56 = loc("x0"(#loc8)) +#loc57 = loc("x1"(#loc9)) +#loc58 = loc("tmp0"(#loc10)) +#loc59 = loc("tmp0"(#loc11)) +#loc60 = loc("tmp0"(#loc12)) +#loc61 = loc("tmp0"(#loc13)) +#loc62 = loc("tmp0"(#loc14)) +#loc63 = loc("tmp0"(#loc15)) +#loc64 = loc("_tmp2"(#loc16)) +#loc65 = loc("r0_index"(#loc17)) +#loc66 = loc("r0_mask"(#loc18)) +#loc67 = loc("tmp0"(#loc19)) +#loc68 = loc("mask"(#loc20)) +#loc69 = loc("equal"(#loc22)) +#loc70 = loc("a_isnan"(#loc23)) +#loc71 = loc("b_isnan"(#loc24)) +#loc72 = loc("mask"(#loc25)) +#loc73 = loc("mask"(#loc26)) +#loc74 = loc("mask"(#loc27)) +#loc75 = loc("equal"(#loc28)) +#loc76 = loc("equal"(#loc29)) +#loc77 = loc("mask"(#loc30)) +#loc78 = loc("mask"(#loc31)) +#loc79 = loc("mask"(#loc32)) +#loc80 = loc(callsite(#loc33 at #loc21)) +#loc81 = loc(callsite(#loc34 at #loc21)) +#loc82 = loc("_tmp2"(#loc35)) +#loc83 = loc("_tmp2_index"(#loc36)) +#loc84 = loc(callsite(#loc38 at #loc39)) +#loc86 = loc("tmp2"(#loc40)) +#loc87 = loc("_tmp2_index"(#loc64)) +#loc88 = loc("mask"(#loc68)) +#loc89 = loc("equal"(#loc69)) +#loc90 = loc(callsite(#loc70 at #loc21)) +#loc91 = loc(callsite(#loc71 at #loc21)) +#loc92 = loc(callsite(#loc72 at #loc21)) +#loc93 = loc(callsite(#loc73 at #loc21)) +#loc94 = loc("mask"(#loc74)) +#loc95 = loc(callsite(#loc75 at #loc21)) +#loc96 = loc("equal"(#loc76)) +#loc97 = loc(callsite(#loc77 at #loc21)) +#loc98 = loc(callsite(#loc78 at #loc21)) +#loc99 = loc(callsite(#loc79 at #loc21)) +#loc100 = loc(callsite(#loc70 at #loc84)) +#loc101 = loc(callsite(#loc71 at #loc84)) +#loc102 = loc(callsite(#loc72 at #loc84)) +#loc103 = loc(callsite(#loc73 at #loc84)) +#loc104 = loc(callsite(#loc75 at #loc84)) +#loc105 = loc(callsite(#loc77 at #loc84)) +#loc106 = loc(callsite(#loc78 at #loc84)) +#loc107 = loc(callsite(#loc79 at #loc84)) +#loc108 = loc(callsite(#loc33 at #loc84)) +#loc109 = loc(callsite(#loc34 at #loc84)) +#loc110 = loc(callsite(#loc88 at 
#loc21)) +#loc111 = loc(callsite(#loc89 at #loc21)) +#loc112 = loc(callsite(#loc94 at #loc21)) +#loc113 = loc(callsite(#loc96 at #loc21)) +#loc114 = loc(callsite(#loc88 at #loc84)) +#loc115 = loc(callsite(#loc89 at #loc84)) +#loc116 = loc(callsite(#loc94 at #loc84)) +#loc117 = loc(callsite(#loc96 at #loc84)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttir b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttir new file mode 100644 index 0000000000000000000000000000000000000000..fb2e506ea50df8259e56e7d9846dd71b7c970d7e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/P3QZALUCGDQIGARFKKM54ROSVAIQRSFS3BCW42DE3EAUTSMCDEEQ/triton_red_fused_argmax_1.ttir @@ -0,0 +1,213 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":18:0) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":45:75) +#loc47 = loc("in_ptr0"(#loc)) +#loc48 = loc("out_ptr0"(#loc)) +#loc49 = loc("ks0"(#loc)) +#loc50 = loc("ks1"(#loc)) +#loc51 = loc("xnumel"(#loc)) +#loc52 = loc("r0_numel"(#loc)) +#loc53 = loc(callsite(#loc1 at #loc2)) +module { + tt.func public @triton_red_fused_argmax_1(%in_ptr0: !tt.ptr loc("in_ptr0"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i64 loc("ks0"(#loc)), %ks1: i64 loc("ks1"(#loc)), %xnumel: i32 loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %true = arith.constant true loc(#loc53) + %cst = arith.constant dense : tensor<64x64xi1> loc(#loc1) + %c32000_i32 = arith.constant 32000 : i32 loc(#loc3) + %c0_i32 = arith.constant 0 : i32 loc(#loc3) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> 
loc(#loc1) + %cst_1 = arith.constant dense<32000> : tensor<64x1xi64> loc(#loc1) + %cst_2 = arith.constant dense<32000> : tensor<1x64xi32> loc(#loc1) + %_tmp2_index = arith.constant dense<2147483647> : tensor<64x64xi32> loc(#loc54) + %_tmp2 = arith.constant dense<0xFF800000> : tensor<64x64xf32> loc(#loc55) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc56) + %xoffset_3 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc57) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc58) + %xindex_4 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc59) + %xindex_5 = tt.splat %xoffset_3 : i32 -> tensor<64x1xi32> loc(#loc60) + %xindex_6 = arith.addi %xindex_5, %xindex_4 : tensor<64x1xi32> loc(#loc60) + %xmask = tt.splat %xnumel : i32 -> tensor<64x1xi32> loc(#loc61) + %xmask_7 = arith.cmpi slt, %xindex_6, %xmask : tensor<64x1xi32> loc(#loc61) + %r0_base = tt.expand_dims %xindex {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc62) + %x0 = arith.extsi %xindex_6 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc63) + %x0_8 = tt.splat %ks0 : i64 -> tensor<64x1xi64> loc(#loc63) + %x0_9 = arith.remsi %x0, %x0_8 : tensor<64x1xi64> loc(#loc63) + %x1 = arith.divsi %x0, %x0_8 : tensor<64x1xi64> loc(#loc64) + %_tmp2_index_10:2 = scf.for %r0_offset = %c0_i32 to %c32000_i32 step %c64_i32 iter_args(%_tmp2_11 = %_tmp2, %_tmp2_index_12 = %_tmp2_index) -> (tensor<64x64xf32>, tensor<64x64xi32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x64xi32> loc(#loc66) + %r0_index_13 = arith.addi %r0_index, %r0_base : tensor<1x64xi32> loc(#loc66) + %r0_mask = arith.cmpi slt, %r0_index_13, %cst_2 : tensor<1x64xi32> loc(#loc67) + %tmp0 = arith.muli %x0_9, %cst_1 : tensor<64x1xi64> loc(#loc68) + %tmp0_14 = arith.extsi %r0_index_13 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc69) + %tmp0_15 = tt.broadcast %tmp0_14 : tensor<1x64xi64> -> tensor<64x64xi64> loc(#loc69) + 
%tmp0_16 = tt.broadcast %tmp0 : tensor<64x1xi64> -> tensor<64x64xi64> loc(#loc69) + %tmp0_17 = arith.addi %tmp0_15, %tmp0_16 : tensor<64x64xi64> loc(#loc69) + %tmp0_18 = tt.splat %ks1 : i64 -> tensor<64x1xi64> loc(#loc70) + %tmp0_19 = arith.muli %tmp0_18, %x1 : tensor<64x1xi64> loc(#loc70) + %tmp0_20 = tt.broadcast %tmp0_19 : tensor<64x1xi64> -> tensor<64x64xi64> loc(#loc71) + %tmp0_21 = arith.addi %tmp0_17, %tmp0_20 : tensor<64x64xi64> loc(#loc71) + %tmp0_22 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x64x!tt.ptr> loc(#loc72) + %tmp0_23 = tt.addptr %tmp0_22, %tmp0_21 : tensor<64x64x!tt.ptr>, tensor<64x64xi64> loc(#loc72) + %tmp0_24 = tt.broadcast %r0_mask : tensor<1x64xi1> -> tensor<64x64xi1> loc(#loc73) + %tmp0_25 = tt.broadcast %xmask_7 : tensor<64x1xi1> -> tensor<64x64xi1> loc(#loc73) + %tmp0_26 = arith.andi %tmp0_24, %tmp0_25 : tensor<64x64xi1> loc(#loc73) + %tmp0_27 = tt.load %tmp0_23, %tmp0_26, %cst_0 evictionPolicy = evict_first : tensor<64x64x!tt.ptr> loc(#loc74) + %mask = arith.cmpf ogt, %_tmp2_11, %tmp0_27 : tensor<64x64xf32> loc(#loc116) + %equal = arith.cmpf oeq, %_tmp2_11, %tmp0_27 : tensor<64x64xf32> loc(#loc117) + %a_isnan = arith.cmpf une, %_tmp2_11, %_tmp2_11 : tensor<64x64xf32> loc(#loc96) + %b_isnan = arith.cmpf une, %tmp0_27, %tmp0_27 : tensor<64x64xf32> loc(#loc97) + %mask_28 = arith.xori %b_isnan, %cst : tensor<64x64xi1> loc(#loc98) + %mask_29 = arith.andi %a_isnan, %mask_28 : tensor<64x64xi1> loc(#loc99) + %mask_30 = arith.ori %mask, %mask_29 : tensor<64x64xi1> loc(#loc118) + %equal_31 = arith.andi %a_isnan, %b_isnan : tensor<64x64xi1> loc(#loc101) + %equal_32 = arith.ori %equal, %equal_31 : tensor<64x64xi1> loc(#loc119) + %mask_33 = tt.broadcast %r0_index_13 : tensor<1x64xi32> -> tensor<64x64xi32> loc(#loc103) + %mask_34 = arith.cmpi slt, %_tmp2_index_12, %mask_33 : tensor<64x64xi32> loc(#loc103) + %mask_35 = arith.andi %equal_32, %mask_34 : tensor<64x64xi1> loc(#loc104) + %mask_36 = arith.ori %mask_30, %mask_35 : tensor<64x64xi1> 
loc(#loc105) + %4 = arith.select %mask_36, %_tmp2_11, %tmp0_27 : tensor<64x64xi1>, tensor<64x64xf32> loc(#loc87) + %5 = arith.select %mask_36, %_tmp2_index_12, %mask_33 : tensor<64x64xi1>, tensor<64x64xi32> loc(#loc88) + %_tmp2_37 = arith.select %tmp0_26, %4, %_tmp2_11 : tensor<64x64xi1>, tensor<64x64xf32> loc(#loc89) + %_tmp2_index_38 = arith.select %tmp0_26, %5, %_tmp2_index_12 : tensor<64x64xi1>, tensor<64x64xi32> loc(#loc90) + scf.yield %_tmp2_37, %_tmp2_index_38 : tensor<64x64xf32>, tensor<64x64xi32> loc(#loc41) + } loc(#loc93) + %0:2 = "tt.reduce"(%_tmp2_index_10#0, %_tmp2_index_10#1) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc2)), %arg7: i32 loc(callsite(#loc1 at #loc2)), %arg8: f32 loc(callsite(#loc1 at #loc2)), %arg9: i32 loc(callsite(#loc1 at #loc2))): + %mask = arith.cmpf ogt, %arg6, %arg8 : f32 loc(#loc120) + %equal = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc121) + %a_isnan = arith.cmpf une, %arg6, %arg6 : f32 loc(#loc106) + %b_isnan = arith.cmpf une, %arg8, %arg8 : f32 loc(#loc107) + %mask_11 = arith.xori %b_isnan, %true : i1 loc(#loc108) + %mask_12 = arith.andi %a_isnan, %mask_11 : i1 loc(#loc109) + %mask_13 = arith.ori %mask, %mask_12 : i1 loc(#loc122) + %equal_14 = arith.andi %a_isnan, %b_isnan : i1 loc(#loc110) + %equal_15 = arith.ori %equal, %equal_14 : i1 loc(#loc123) + %mask_16 = arith.cmpi slt, %arg7, %arg9 : i32 loc(#loc111) + %mask_17 = arith.andi %equal_15, %mask_16 : i1 loc(#loc112) + %mask_18 = arith.ori %mask_13, %mask_17 : i1 loc(#loc113) + %4 = arith.select %mask_18, %arg6, %arg8 : f32 loc(#loc114) + %5 = arith.select %mask_18, %arg7, %arg9 : i32 loc(#loc115) + tt.reduce.return %4, %5 : f32, i32 loc(#loc91) + }) : (tensor<64x64xf32>, tensor<64x64xi32>) -> (tensor<64xf32>, tensor<64xi32>) loc(#loc91) + %tmp2 = tt.expand_dims %0#1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc92) + %1 = tt.splat %out_ptr0 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc44) + %2 = tt.addptr %1, %xindex_6 : 
tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc44) + %3 = arith.extsi %tmp2 : tensor<64x1xi32> to tensor<64x1xi64> loc(#loc45) + tt.store %2, %3, %xmask_7 : tensor<64x1x!tt.ptr> loc(#loc45) + tt.return loc(#loc46) + } loc(#loc) +} loc(#loc) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":32:40) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":30:58) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":29:55) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":22:28) +#loc7 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":22:33) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":23:36) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":23:44) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":23:23) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":24:21) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":25:37) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":27:19) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":28:19) +#loc15 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":33:31) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":34:29) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:47) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:41) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:56) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:52) +#loc21 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:34) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:71) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":38:61) +#loc24 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":144:21) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":41:38) +#loc26 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":145:23) +#loc27 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":147:29) +#loc28 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":148:29) +#loc29 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:31) +#loc30 = 
loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:27) +#loc31 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":149:16) +#loc32 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:27) +#loc33 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":151:17) +#loc34 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:31) +#loc35 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:21) +#loc36 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":154:12) +#loc37 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:35) +#loc38 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":155:69) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":43:54) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":44:66) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":44:8) +#loc42 = loc("/workspace/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":165:42) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":46:20) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:25) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:36) +#loc46 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/va/cva33w57cjvmmyeqhz3hzaghoxh7tysu4kjxg24hbdgtwec5cn3s.py":47:4) +#loc54 = loc("_tmp2_index"(#loc4)) +#loc55 = loc("_tmp2"(#loc5)) +#loc56 = loc("xoffset"(#loc6)) +#loc57 = loc("xoffset"(#loc7)) +#loc58 = loc("xindex"(#loc8)) +#loc59 = loc("xindex"(#loc9)) +#loc60 = loc("xindex"(#loc10)) +#loc61 = loc("xmask"(#loc11)) +#loc62 = loc("r0_base"(#loc12)) +#loc63 = loc("x0"(#loc13)) +#loc64 = loc("x1"(#loc14)) +#loc65 = loc("_tmp2"(#loc3)) +#loc66 = loc("r0_index"(#loc15)) +#loc67 = loc("r0_mask"(#loc16)) +#loc68 = loc("tmp0"(#loc17)) +#loc69 = loc("tmp0"(#loc18)) +#loc70 = loc("tmp0"(#loc19)) +#loc71 = loc("tmp0"(#loc20)) +#loc72 = loc("tmp0"(#loc21)) +#loc73 = loc("tmp0"(#loc22)) +#loc74 = loc("tmp0"(#loc23)) +#loc75 = loc("mask"(#loc24)) +#loc76 = loc("equal"(#loc26)) +#loc77 = loc("a_isnan"(#loc27)) +#loc78 = loc("b_isnan"(#loc28)) +#loc79 = loc("mask"(#loc29)) +#loc80 = loc("mask"(#loc30)) +#loc81 = loc("mask"(#loc31)) +#loc82 = loc("equal"(#loc32)) +#loc83 = loc("equal"(#loc33)) +#loc84 = loc("mask"(#loc34)) +#loc85 = loc("mask"(#loc35)) +#loc86 = loc("mask"(#loc36)) +#loc87 = loc(callsite(#loc37 at #loc25)) +#loc88 = loc(callsite(#loc38 at #loc25)) +#loc89 = loc("_tmp2"(#loc39)) +#loc90 = loc("_tmp2_index"(#loc40)) +#loc91 = loc(callsite(#loc42 at #loc2)) +#loc92 = loc("tmp2"(#loc43)) +#loc93 = loc("_tmp2_index"(#loc65)) +#loc94 = loc("mask"(#loc75)) +#loc95 = loc("equal"(#loc76)) +#loc96 = loc(callsite(#loc77 at #loc25)) +#loc97 = loc(callsite(#loc78 at #loc25)) +#loc98 = loc(callsite(#loc79 at #loc25)) +#loc99 = loc(callsite(#loc80 at #loc25)) +#loc100 = loc("mask"(#loc81)) +#loc101 = loc(callsite(#loc82 at #loc25)) +#loc102 = loc("equal"(#loc83)) +#loc103 = loc(callsite(#loc84 at #loc25)) +#loc104 = loc(callsite(#loc85 at #loc25)) +#loc105 = loc(callsite(#loc86 at #loc25)) +#loc106 = loc(callsite(#loc77 at #loc91)) +#loc107 = loc(callsite(#loc78 at #loc91)) +#loc108 = loc(callsite(#loc79 at #loc91)) 
+#loc109 = loc(callsite(#loc80 at #loc91)) +#loc110 = loc(callsite(#loc82 at #loc91)) +#loc111 = loc(callsite(#loc84 at #loc91)) +#loc112 = loc(callsite(#loc85 at #loc91)) +#loc113 = loc(callsite(#loc86 at #loc91)) +#loc114 = loc(callsite(#loc37 at #loc91)) +#loc115 = loc(callsite(#loc38 at #loc91)) +#loc116 = loc(callsite(#loc94 at #loc25)) +#loc117 = loc(callsite(#loc95 at #loc25)) +#loc118 = loc(callsite(#loc100 at #loc25)) +#loc119 = loc(callsite(#loc102 at #loc25)) +#loc120 = loc(callsite(#loc94 at #loc91)) +#loc121 = loc(callsite(#loc95 at #loc91)) +#loc122 = loc(callsite(#loc100 at #loc91)) +#loc123 = loc(callsite(#loc102 at #loc91)) diff --git a/SpecForge-ext/cache/compiled_kernels/triton/0/VRYS24QWRHHWZGMHQTKUKRBT44JWBUITDFBXCOHWEBQVQBBVUKNA/triton_tem_fused_0.source b/SpecForge-ext/cache/compiled_kernels/triton/0/VRYS24QWRHHWZGMHQTKUKRBT44JWBUITDFBXCOHWEBQVQBBVUKNA/triton_tem_fused_0.source new file mode 100644 index 0000000000000000000000000000000000000000..569613b7ce2b6e5350c3d4c579900c194a334dbb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/VRYS24QWRHHWZGMHQTKUKRBT44JWBUITDFBXCOHWEBQVQBBVUKNA/triton_tem_fused_0.source @@ -0,0 +1,1203 @@ +#loc = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":18:0) +#loc127 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":271:0) +#loc139 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":32:0) +#loc145 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":448:0) +#loc155 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":299:0) +#loc227 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":256:0) +#loc231 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":175:0) +#loc233 = loc(unknown) +#loc236 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":167:0) +#loc240 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc244 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc248 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":240:0) +#loc269 = loc("arg_Q"(#loc)) +#loc270 = loc("arg_K"(#loc)) +#loc271 = loc("arg_V"(#loc)) +#loc272 = loc("arg_LSE"(#loc)) +#loc273 = loc("arg_MAX"(#loc)) +#loc274 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc275 = loc("arg_KV_IDX"(#loc)) +#loc276 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc277 = loc("arg_FULL_KV_IDX"(#loc)) +#loc278 = loc("in_ptr9"(#loc)) +#loc279 = loc("out_ptr0"(#loc)) +#loc280 = loc("ks0"(#loc)) +#loc281 = loc("ks1"(#loc)) +#loc282 = loc("ks2"(#loc)) +#loc283 = loc("ks3"(#loc)) +#loc284 = loc("ks4"(#loc)) +#loc285 = loc("ks5"(#loc)) +#loc382 = loc("ptr"(#loc127)) +#loc383 = loc("offs_m"(#loc127)) +#loc384 = loc("offs_n"(#loc127)) +#loc385 = loc("stride_m"(#loc127)) +#loc386 = loc("stride_n"(#loc127)) +#loc387 = loc("M_LEN"(#loc127)) +#loc394 = loc("x"(#loc139)) +#loc395 = loc("arg_Q"(#loc145)) +#loc396 = loc("arg_K"(#loc145)) +#loc397 = loc("arg_V"(#loc145)) +#loc398 = loc("arg_LSE"(#loc145)) +#loc399 = loc("arg_MAX"(#loc145)) +#loc400 = loc("arg_KV_NUM_BLKS"(#loc145)) +#loc401 = loc("arg_KV_IDX"(#loc145)) +#loc402 = loc("arg_FULL_KV_NUM_BLKS"(#loc145)) +#loc403 = loc("arg_FULL_KV_IDX"(#loc145)) +#loc404 = loc("in_ptr9"(#loc145)) +#loc405 = loc("out_ptr0"(#loc145)) +#loc406 = loc("ks0"(#loc145)) +#loc407 = loc("ks1"(#loc145)) +#loc408 = loc("ks2"(#loc145)) +#loc409 = loc("ks3"(#loc145)) +#loc410 = loc("ks4"(#loc145)) +#loc411 = loc("ks5"(#loc145)) +#loc412 = loc("q"(#loc145)) +#loc413 = loc("K"(#loc145)) 
+#loc414 = loc("V"(#loc145)) +#loc415 = loc("Q_LEN"(#loc145)) +#loc416 = loc("KV_LEN"(#loc145)) +#loc417 = loc("acc"(#loc145)) +#loc418 = loc("l_i"(#loc145)) +#loc419 = loc("m_i"(#loc145)) +#loc420 = loc("off_z"(#loc145)) +#loc421 = loc("off_h"(#loc145)) +#loc422 = loc("offs_m"(#loc145)) +#loc423 = loc("offs_n"(#loc145)) +#loc424 = loc("kv_start"(#loc145)) +#loc425 = loc("kv_indices"(#loc145)) +#loc426 = loc("kv_num_blocks"(#loc145)) +#loc427 = loc("block_n_end"(#loc145)) +#loc428 = loc("stride_kk"(#loc145)) +#loc429 = loc("stride_kn"(#loc145)) +#loc430 = loc("stride_vn"(#loc145)) +#loc431 = loc("stride_vk"(#loc145)) +#loc437 = loc("arg_Q"(#loc155)) +#loc438 = loc("arg_K"(#loc155)) +#loc439 = loc("arg_V"(#loc155)) +#loc440 = loc("arg_LSE"(#loc155)) +#loc441 = loc("arg_MAX"(#loc155)) +#loc442 = loc("arg_KV_NUM_BLKS"(#loc155)) +#loc443 = loc("arg_KV_IDX"(#loc155)) +#loc444 = loc("arg_FULL_KV_NUM_BLKS"(#loc155)) +#loc445 = loc("arg_FULL_KV_IDX"(#loc155)) +#loc446 = loc("in_ptr9"(#loc155)) +#loc447 = loc("out_ptr0"(#loc155)) +#loc448 = loc("ks0"(#loc155)) +#loc449 = loc("ks1"(#loc155)) +#loc450 = loc("ks2"(#loc155)) +#loc451 = loc("ks3"(#loc155)) +#loc452 = loc("ks4"(#loc155)) +#loc453 = loc("ks5"(#loc155)) +#loc454 = loc("q"(#loc155)) +#loc455 = loc("K"(#loc155)) +#loc456 = loc("V"(#loc155)) +#loc457 = loc("Q_LEN"(#loc155)) +#loc458 = loc("KV_LEN"(#loc155)) +#loc459 = loc("acc"(#loc155)) +#loc460 = loc("l_i"(#loc155)) +#loc461 = loc("m_i"(#loc155)) +#loc462 = loc("off_z"(#loc155)) +#loc463 = loc("off_h"(#loc155)) +#loc464 = loc("offs_m"(#loc155)) +#loc465 = loc("offs_n"(#loc155)) +#loc466 = loc("kv_start"(#loc155)) +#loc467 = loc("kv_offset"(#loc155)) +#loc468 = loc("stride_kk"(#loc155)) +#loc469 = loc("stride_kn"(#loc155)) +#loc470 = loc("stride_vn"(#loc155)) +#loc471 = loc("stride_vk"(#loc155)) +#loc541 = loc("indices"(#loc227)) +#loc542 = loc("max_len"(#loc227)) +#loc543 = loc("input"(#loc231)) +#loc544 = loc("a"(#loc236)) +#loc545 = loc("b"(#loc236)) +#loc546 = 
loc("input"(#loc240)) +#loc547 = loc("a"(#loc244)) +#loc548 = loc("b"(#loc244)) +#loc549 = loc("loop_iter"(#loc248)) +#loc550 = loc("col_indices"(#loc248)) +#loc551 = loc("total_blocks"(#loc248)) +module { + tt.func public @triton_tem_fused_0(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_MAX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_MAX"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %in_ptr9: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr9"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc)), %ks2: i32 loc("ks2"(#loc)), %ks3: i32 loc("ks3"(#loc)), %ks4: i32 loc("ks4"(#loc)), %ks5: i32 loc("ks5"(#loc))) attributes {noinline = false} { + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c4096_i32_0 = arith.constant 4096 : i32 loc(#loc1) + %0 = arith.muli %c4096_i32_0, %ks0 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc2) + %c4096_i32_1 = arith.constant 4096 : i32 loc(#loc2) + %c1_i32 = arith.constant 1 : i32 loc(#loc2) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc3) + %c1024_i32_2 = arith.constant 1024 : i32 loc(#loc3) + %1 = arith.muli %c1024_i32_2, %ks1 : i32 loc(#loc3) + %c128_i32_3 = arith.constant 128 : i32 loc(#loc4) + %c128_i32_4 = arith.constant 128 : i32 loc(#loc4) + %2 = arith.muli %c128_i32_4, %ks1 : i32 loc(#loc4) + %c128_i32_5 = arith.constant 128 : i32 loc(#loc5) + %c1_i32_6 = arith.constant 1 : i32 loc(#loc5) + %c1024_i32_7 = 
arith.constant 1024 : i32 loc(#loc6) + %c1024_i32_8 = arith.constant 1024 : i32 loc(#loc6) + %3 = arith.muli %c1024_i32_8, %ks1 : i32 loc(#loc6) + %c128_i32_9 = arith.constant 128 : i32 loc(#loc7) + %c128_i32_10 = arith.constant 128 : i32 loc(#loc7) + %4 = arith.muli %c128_i32_10, %ks1 : i32 loc(#loc7) + %c128_i32_11 = arith.constant 128 : i32 loc(#loc8) + %c1_i32_12 = arith.constant 1 : i32 loc(#loc8) + %ZQ = arith.constant 8 : i32 loc(#loc286) + %HQ = arith.constant 32 : i32 loc(#loc287) + %ZKV = arith.constant 8 : i32 loc(#loc288) + %q_start = tt.get_program_id x : i32 loc(#loc289) + %off_zq = tt.get_program_id y : i32 loc(#loc290) + %off_hq = tt.get_program_id z : i32 loc(#loc291) + %off_zkv = arith.remsi %off_zq, %ZKV : i32 loc(#loc292) + %off_hkv = arith.constant 4 : i32 loc(#loc293) + %off_hkv_13 = arith.constant 4 : i32 loc(#loc293) + %off_hkv_14 = arith.divsi %off_hq, %off_hkv_13 : i32 loc(#loc293) + %off_g = arith.constant 4 : i32 loc(#loc294) + %off_g_15 = arith.constant 4 : i32 loc(#loc294) + %off_g_16 = arith.remsi %off_hq, %off_g_15 : i32 loc(#loc294) + %q_offset = arith.muli %off_zq, %0 : i32 loc(#loc295) + %q_offset_17 = arith.muli %off_hq, %c128_i32 : i32 loc(#loc296) + %q_offset_18 = arith.addi %q_offset, %q_offset_17 : i32 loc(#loc297) + %k_offset = arith.muli %off_zkv, %1 : i32 loc(#loc298) + %k_offset_19 = arith.muli %off_hkv_14, %2 : i32 loc(#loc299) + %k_offset_20 = arith.addi %k_offset, %k_offset_19 : i32 loc(#loc300) + %v_offset = arith.muli %off_zkv, %3 : i32 loc(#loc301) + %v_offset_21 = arith.muli %off_hkv_14, %4 : i32 loc(#loc302) + %v_offset_22 = arith.addi %v_offset, %v_offset_21 : i32 loc(#loc303) + %Q = tt.addptr %arg_Q, %q_offset_18 : !tt.ptr, i32 loc(#loc304) + %K = tt.addptr %arg_K, %k_offset_20 : !tt.ptr, i32 loc(#loc305) + %V = tt.addptr %arg_V, %v_offset_22 : !tt.ptr, i32 loc(#loc306) + %SPARSE_Z = arith.constant 8 : i32 loc(#loc307) + %SPARSE_HQ = arith.constant 1 : i32 loc(#loc308) + %sparse_idx_z = arith.remsi %off_zq, 
%SPARSE_Z : i32 loc(#loc309) + %sparse_idx_hq = arith.remsi %off_hq, %SPARSE_HQ : i32 loc(#loc310) + %stride_kv_idx_h = arith.muli %ks3, %ks4 : i32 loc(#loc311) + %m_i = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128xf32> loc(#loc312) + %m_i_23 = arith.constant 0x7F800000 : f32 loc(#loc313) + %m_i_24 = arith.constant 0x7F800000 : f32 loc(#loc313) + %m_i_25 = arith.constant dense<0x7F800000> : tensor<128xf32> loc(#loc313) + %m_i_26 = arith.subf %m_i, %m_i_25 : tensor<128xf32> loc(#loc313) + %l_i = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128xf32> loc(#loc314) + %acc = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc315) + %offs_m = arith.constant 128 : i32 loc(#loc316) + %offs_m_27 = arith.constant 128 : i32 loc(#loc316) + %offs_m_28 = arith.muli %q_start, %offs_m_27 : i32 loc(#loc316) + %offs_m_29 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc317) + %offs_m_30 = tt.splat %offs_m_28 : i32 -> tensor<128xi32> loc(#loc318) + %offs_m_31 = arith.addi %offs_m_30, %offs_m_29 : tensor<128xi32> loc(#loc318) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc319) + %sparse_hz_offset_32 = arith.addi %sparse_hz_offset, %sparse_idx_hq : i32 loc(#loc320) + %sparse_kv_num_blks_offset = arith.muli %sparse_hz_offset_32, %ks2 : i32 loc(#loc321) + %sparse_kv_num_blks_offset_33 = arith.constant 1 : i32 loc(#loc322) + %sparse_kv_num_blks_offset_34 = arith.constant 1 : i32 loc(#loc322) + %sparse_kv_num_blks_offset_35 = arith.divsi %q_start, %sparse_kv_num_blks_offset_34 : i32 loc(#loc322) + %sparse_kv_num_blks_offset_36 = arith.addi %sparse_kv_num_blks_offset, %sparse_kv_num_blks_offset_35 : i32 loc(#loc323) + %sparse_kv_idx_offset = arith.muli %sparse_hz_offset_32, %stride_kv_idx_h : i32 loc(#loc324) + 
%sparse_kv_idx_offset_37 = arith.constant 1 : i32 loc(#loc325) + %sparse_kv_idx_offset_38 = arith.constant 1 : i32 loc(#loc325) + %sparse_kv_idx_offset_39 = arith.divsi %q_start, %sparse_kv_idx_offset_38 : i32 loc(#loc325) + %sparse_kv_idx_offset_40 = arith.muli %sparse_kv_idx_offset_39, %ks4 : i32 loc(#loc326) + %sparse_kv_idx_offset_41 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_40 : i32 loc(#loc327) + %offs_m_42 = arith.constant 128 : i32 loc(#loc328) + %offs_m_43 = arith.constant 128 : i32 loc(#loc328) + %offs_m_44 = arith.muli %q_start, %offs_m_43 : i32 loc(#loc328) + %offs_m_45 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc329) + %offs_m_46 = tt.splat %offs_m_44 : i32 -> tensor<128xi32> loc(#loc330) + %offs_m_47 = arith.addi %offs_m_46, %offs_m_45 : tensor<128xi32> loc(#loc330) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc331) + %q = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%Q, %offs_m_47, %offs_k, %c4096_i32_1, %c1_i32, %ks0) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc332) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_41 : !tt.ptr, i32 loc(#loc333) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc334) + %kv_start_48 = arith.constant 128 : i32 loc(#loc335) + %kv_start_49 = arith.constant 128 : i32 loc(#loc335) + %kv_start_50 = arith.muli %kv_start, %kv_start_49 : i32 loc(#loc335) + %kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_36 : !tt.ptr, i32 loc(#loc336) + %kv_num_blocks_51 = tt.load %kv_num_blocks : !tt.ptr loc(#loc337) + %block_n_end = arith.constant 2 : i32 loc(#loc338) + %block_n_end_52 = arith.constant 2 : i32 loc(#loc338) + %block_n_end_53 = arith.muli %kv_num_blocks_51, %block_n_end_52 : i32 
loc(#loc338) + %block_n_end_54 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc339) + %block_n_end_55 = arith.constant 1 : i32 loc(#loc340) + %block_n_end_56 = arith.maxsi %block_n_end_54, %block_n_end_55 : i32 loc(#loc340) + %block_n_end_57 = arith.minsi %block_n_end_53, %block_n_end_56 : i32 loc(#loc341) + %offs_n = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc342) + %offs_n_58 = tt.splat %kv_start_50 : i32 -> tensor<64xi32> loc(#loc343) + %offs_n_59 = arith.addi %offs_n_58, %offs_n : tensor<64xi32> loc(#loc343) + %5 = tt.expand_dims %offs_m_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc67) + %6 = tt.expand_dims %offs_n_59 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc68) + %7:3 = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_Pi32_i32_i32_i32_i32_i32_i32__(20,)cNone_(21,)cNone_(34,)cconstexpr_0__(36,)cconstexpr_bf16__(41,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_MAX, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %in_ptr9, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %q, %K, %V, %ks0, %ks1, %acc, %l_i, %m_i_26, %off_zq, %off_hq, %5, %6, %kv_start_50, %kv_indices, %kv_num_blocks_51, %block_n_end_57, %c1_i32_6, %c128_i32_5, %c128_i32_11, %c1_i32_12) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, tensor<128x128xbf16>, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, i32, i32, tensor<128x1xi32>, tensor<1x64xi32>, i32, !tt.ptr, i32, i32, i32, i32, i32, i32) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) loc(#loc69) + %kv_indices_60 = 
tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_41 : !tt.ptr, i32 loc(#loc344) + %kv_start_61 = tt.load %kv_indices_60 : !tt.ptr loc(#loc345) + %kv_start_62 = arith.constant 128 : i32 loc(#loc346) + %kv_start_63 = arith.constant 128 : i32 loc(#loc346) + %kv_start_64 = arith.muli %kv_start_61, %kv_start_63 : i32 loc(#loc346) + %kv_num_blocks_65 = tt.addptr %arg_FULL_KV_NUM_BLKS, %sparse_kv_num_blks_offset_36 : !tt.ptr, i32 loc(#loc347) + %kv_num_blocks_66 = tt.load %kv_num_blocks_65 : !tt.ptr loc(#loc348) + %block_n_end_67 = arith.constant 2 : i32 loc(#loc349) + %block_n_end_68 = arith.constant 2 : i32 loc(#loc349) + %block_n_end_69 = arith.muli %kv_num_blocks_66, %block_n_end_68 : i32 loc(#loc349) + %block_n_end_70 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc350) + %block_n_end_71 = arith.constant 1 : i32 loc(#loc351) + %block_n_end_72 = arith.maxsi %block_n_end_70, %block_n_end_71 : i32 loc(#loc351) + %block_n_end_73 = arith.minsi %block_n_end_69, %block_n_end_72 : i32 loc(#loc352) + %offs_n_74 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc353) + %offs_n_75 = tt.splat %kv_start_64 : i32 -> tensor<64xi32> loc(#loc354) + %offs_n_76 = arith.addi %offs_n_75, %offs_n_74 : tensor<64xi32> loc(#loc354) + %8 = tt.expand_dims %offs_m_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc81) + %9 = tt.expand_dims %offs_n_76 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc82) + %10:3 = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_Pi32_i32_i32_i32_i32_i32_i32__(20,)cNone_(21,)cNone_(34,)cconstexpr_0__(36,)cconstexpr_bf16__(41,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_MAX, %arg_KV_NUM_BLKS, 
%arg_KV_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %in_ptr9, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %q, %K, %V, %ks0, %ks1, %7#0, %7#1, %7#2, %off_zq, %off_hq, %8, %9, %kv_start_64, %kv_indices_60, %kv_num_blocks_66, %block_n_end_73, %c1_i32_6, %c128_i32_5, %c128_i32_11, %c1_i32_12) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, tensor<128x128xbf16>, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, i32, i32, tensor<128x1xi32>, tensor<1x64xi32>, i32, !tt.ptr, i32, i32, i32, i32, i32, i32) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) loc(#loc83) + %l_i_77 = arith.constant 0.000000e+00 : f32 loc(#loc355) + %l_i_78 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc355) + %l_i_79 = arith.cmpf oeq, %10#1, %l_i_78 : tensor<128xf32> loc(#loc355) + %l_i_80 = arith.constant 1 : i32 loc(#loc356) + %l_i_81 = arith.constant 1.000000e+00 : f32 loc(#loc356) + %l_i_82 = arith.constant dense<1.000000e+00> : tensor<128xf32> loc(#loc356) + %l_i_83 = arith.select %l_i_79, %l_i_82, %10#1 : tensor<128xi1>, tensor<128xf32> loc(#loc356) + %acc_84 = tt.expand_dims %l_i_83 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc357) + %acc_85 = tt.broadcast %acc_84 : tensor<128x1xf32> -> tensor<128x128xf32> loc(#loc358) + %acc_86 = arith.divf %10#0, %acc_85 : tensor<128x128xf32> loc(#loc358) + %idx_zq = tt.get_program_id y : i32 loc(#loc359) + %idx_hq = tt.get_program_id z : i32 loc(#loc360) + %idx_m = tt.expand_dims %offs_m_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc361) + %idx_d = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc362) + %idx_d_87 = tt.expand_dims %idx_d {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc363) + %mask = tt.splat %ks0 : i32 -> tensor<128x1xi32> loc(#loc364) + %mask_88 = arith.cmpi slt, %idx_m, %mask : tensor<128x1xi32> loc(#loc364) + 
%mask_89 = arith.constant 128 : i32 loc(#loc365) + %mask_90 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc365) + %mask_91 = arith.cmpi slt, %idx_d_87, %mask_90 : tensor<1x128xi32> loc(#loc365) + %mask_92 = tt.broadcast %mask_88 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc366) + %mask_93 = tt.broadcast %mask_91 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc366) + %mask_94 = arith.andi %mask_92, %mask_93 : tensor<128x128xi1> loc(#loc366) + %xindex = arith.constant 128 : i32 loc(#loc367) + %xindex_95 = arith.constant 128 : i32 loc(#loc367) + %xindex_96 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc367) + %xindex_97 = arith.muli %xindex_96, %idx_m : tensor<128x1xi32> loc(#loc367) + %xindex_98 = tt.broadcast %idx_d_87 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc368) + %xindex_99 = tt.broadcast %xindex_97 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc368) + %xindex_100 = arith.addi %xindex_98, %xindex_99 : tensor<128x128xi32> loc(#loc368) + %xindex_101 = arith.constant 128 : i32 loc(#loc369) + %xindex_102 = arith.constant 128 : i32 loc(#loc369) + %xindex_103 = arith.muli %xindex_102, %idx_hq : i32 loc(#loc369) + %xindex_104 = arith.muli %xindex_103, %ks0 : i32 loc(#loc370) + %xindex_105 = tt.splat %xindex_104 : i32 -> tensor<128x128xi32> loc(#loc371) + %xindex_106 = arith.addi %xindex_100, %xindex_105 : tensor<128x128xi32> loc(#loc371) + %xindex_107 = arith.constant 4096 : i32 loc(#loc372) + %xindex_108 = arith.constant 4096 : i32 loc(#loc372) + %xindex_109 = arith.muli %xindex_108, %idx_zq : i32 loc(#loc372) + %xindex_110 = arith.muli %xindex_109, %ks0 : i32 loc(#loc373) + %xindex_111 = tt.splat %xindex_110 : i32 -> tensor<128x128xi32> loc(#loc374) + %xindex_112 = arith.addi %xindex_106, %xindex_111 : tensor<128x128xi32> loc(#loc374) + %c128_i32_113 = arith.constant 128 : i32 loc(#loc104) + %c128_i32_114 = arith.constant 128 : i32 loc(#loc104) + %11 = arith.muli %c128_i32_114, %idx_hq : i32 loc(#loc104) + %12 = tt.splat %11 : i32 
-> tensor<1x128xi32> loc(#loc105) + %13 = arith.addi %idx_d_87, %12 : tensor<1x128xi32> loc(#loc105) + %c4096_i32_115 = arith.constant 4096 : i32 loc(#loc106) + %c4096_i32_116 = arith.constant 4096 : i32 loc(#loc106) + %cst = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc106) + %14 = arith.muli %cst, %idx_m : tensor<128x1xi32> loc(#loc106) + %15 = tt.broadcast %13 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc107) + %16 = tt.broadcast %14 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc107) + %17 = arith.addi %15, %16 : tensor<128x128xi32> loc(#loc107) + %c4096_i32_117 = arith.constant 4096 : i32 loc(#loc108) + %c4096_i32_118 = arith.constant 4096 : i32 loc(#loc108) + %18 = arith.muli %c4096_i32_118, %idx_zq : i32 loc(#loc108) + %19 = arith.muli %18, %ks0 : i32 loc(#loc109) + %20 = tt.splat %19 : i32 -> tensor<128x128xi32> loc(#loc110) + %21 = arith.addi %17, %20 : tensor<128x128xi32> loc(#loc110) + %22 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc111) + %23 = tt.addptr %22, %21 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc111) + %24 = arith.truncf %acc_86 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc112) + tt.store %23, %24, %mask_94 : tensor<128x128x!tt.ptr> loc(#loc112) + %off_hz = arith.muli %off_zq, %HQ : i32 loc(#loc375) + %off_hz_119 = arith.addi %off_hz, %off_hq : i32 loc(#loc376) + %l_ptrs = arith.muli %off_hz_119, %ks0 : i32 loc(#loc377) + %l_ptrs_120 = tt.addptr %arg_LSE, %l_ptrs : !tt.ptr, i32 loc(#loc378) + %l_ptrs_121 = tt.splat %l_ptrs_120 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc379) + %l_ptrs_122 = tt.addptr %l_ptrs_121, %offs_m_47 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc379) + %lse = math.log2 %l_i_83 : tensor<128xf32> loc(#loc380) + %lse_123 = arith.addf %10#2, %lse : tensor<128xf32> loc(#loc381) + %25 = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc120) + %26 = arith.cmpi slt, %offs_m_47, %25 : tensor<128xi32> loc(#loc120) + tt.store %l_ptrs_122, %lse_123, %26 : tensor<128x!tt.ptr> 
loc(#loc121) + tt.return loc(#loc122) + } loc(#loc) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(1,)cconstexpr_fp32_"() -> tensor<128xf32> attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 loc(#loc124) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc124) + tt.return %cst_0 : tensor<128xf32> loc(#loc125) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128xf32> loc(#loc126) + tt.return %0 : tensor<128xf32> loc(#loc126) + } loc(#loc123) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() -> tensor<128x128xf32> attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 loc(#loc124) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc124) + tt.return %cst_0 : tensor<128x128xf32> loc(#loc125) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc126) + tt.return %0 : tensor<128x128xf32> loc(#loc126) + } loc(#loc123) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: !tt.ptr loc("ptr"(#loc127)), %offs_m: tensor<128xi32> loc("offs_m"(#loc127)), %offs_n: tensor<128xi32> loc("offs_n"(#loc127)), %stride_m: i32 loc("stride_m"(#loc127)), %stride_n: i32 loc("stride_n"(#loc127)), %M_LEN: i32 loc("M_LEN"(#loc127))) -> tensor<128x128xbf16> attributes {noinline = false} { + %ptr_0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc388) + %ptr_1 = tt.splat %stride_m : i32 -> tensor<128x1xi32> loc(#loc389) + %ptr_2 = arith.muli %ptr_0, %ptr_1 : tensor<128x1xi32> loc(#loc389) + %ptr_3 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc390) + %ptr_4 = tt.addptr %ptr_3, %ptr_2 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc390) + %ptr_5 = 
tt.expand_dims %offs_n {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc391) + %ptr_6 = tt.splat %stride_n : i32 -> tensor<1x128xi32> loc(#loc392) + %ptr_7 = arith.muli %ptr_5, %ptr_6 : tensor<1x128xi32> loc(#loc392) + %ptr_8 = tt.broadcast %ptr_4 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc393) + %ptr_9 = tt.broadcast %ptr_7 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc393) + %ptr_10 = tt.addptr %ptr_8, %ptr_9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc393) + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc134) + %1 = tt.splat %M_LEN : i32 -> tensor<128x1xi32> loc(#loc135) + %2 = arith.cmpi slt, %0, %1 : tensor<128x1xi32> loc(#loc135) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc136) + %3 = tt.broadcast %2 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc136) + %cst_11 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc136) + %4 = arith.truncf %cst_11 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc136) + %5 = tt.load %ptr_10, %3, %4 : tensor<128x128x!tt.ptr> loc(#loc136) + tt.return %5 : tensor<128x128xbf16> loc(#loc137) + ^bb1: // no predecessors + %6 = ub.poison : tensor<128x128xbf16> loc(#loc138) + tt.return %6 : tensor<128x128xbf16> loc(#loc138) + } loc(#loc127) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%x: i32 loc("x"(#loc139))) -> i32 attributes {noinline = false} { + %c64_i32 = arith.constant 64 : i32 loc(#loc140) + %c64_i32_0 = arith.constant 64 : i32 loc(#loc140) + %0 = arith.addi %x, %c64_i32_0 : i32 loc(#loc140) + %c1_i32 = arith.constant 1 : i32 loc(#loc141) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc141) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc141) + %c64_i32_2 = arith.constant 64 : i32 loc(#loc142) + %c64_i32_3 = arith.constant 64 : i32 loc(#loc142) + %2 = arith.divsi %1, %c64_i32_3 : i32 loc(#loc142) + tt.return %2 : i32 loc(#loc143) + ^bb1: // no predecessors + %3 = ub.poison : i32 
loc(#loc144) + tt.return %3 : i32 loc(#loc144) + } loc(#loc139) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_Pi32_i32_i32_i32_i32_i32_i32__(20,)cNone_(21,)cNone_(34,)cconstexpr_0__(36,)cconstexpr_bf16__(41,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc145)), %arg_K: !tt.ptr loc("arg_K"(#loc145)), %arg_V: !tt.ptr loc("arg_V"(#loc145)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc145)), %arg_MAX: !tt.ptr loc("arg_MAX"(#loc145)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc145)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc145)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc145)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc145)), %in_ptr9: !tt.ptr loc("in_ptr9"(#loc145)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc145)), %ks0: i32 loc("ks0"(#loc145)), %ks1: i32 loc("ks1"(#loc145)), %ks2: i32 loc("ks2"(#loc145)), %ks3: i32 loc("ks3"(#loc145)), %ks4: i32 loc("ks4"(#loc145)), %ks5: i32 loc("ks5"(#loc145)), %q: tensor<128x128xbf16> loc("q"(#loc145)), %K: !tt.ptr loc("K"(#loc145)), %V: !tt.ptr loc("V"(#loc145)), %Q_LEN: i32 loc("Q_LEN"(#loc145)), %KV_LEN: i32 loc("KV_LEN"(#loc145)), %acc: tensor<128x128xf32> loc("acc"(#loc145)), %l_i: tensor<128xf32> loc("l_i"(#loc145)), %m_i: tensor<128xf32> loc("m_i"(#loc145)), %off_z: i32 loc("off_z"(#loc145)), %off_h: i32 loc("off_h"(#loc145)), %offs_m: tensor<128x1xi32> loc("offs_m"(#loc145)), %offs_n: tensor<1x64xi32> loc("offs_n"(#loc145)), %kv_start: i32 loc("kv_start"(#loc145)), %kv_indices: !tt.ptr loc("kv_indices"(#loc145)), %kv_num_blocks: i32 loc("kv_num_blocks"(#loc145)), %block_n_end: i32 loc("block_n_end"(#loc145)), %stride_kk: i32 loc("stride_kk"(#loc145)), %stride_kn: i32 loc("stride_kn"(#loc145)), %stride_vn: i32 
loc("stride_vn"(#loc145)), %stride_vk: i32 loc("stride_vk"(#loc145))) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) attributes {noinline = false} { + %kv_offset = arith.constant 0 : i32 loc(#loc432) + %c0_i32 = arith.constant 0 : i32 loc(#loc147) + %c1_i32 = arith.constant 1 : i32 loc(#loc147) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc147) + %1 = arith.bitcast %block_n_end : i32 to i32 loc(#loc147) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc147) + %3 = ub.poison : i32 loc(#loc147) + %kv_offset_0:5 = scf.for %start_n = %0 to %1 step %2 iter_args(%acc_1 = %acc, %l_i_2 = %l_i, %m_i_3 = %m_i, %offs_n_4 = %offs_n, %kv_offset_5 = %kv_offset) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<1x64xi32>, i32) : i32 { + %7:3 = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_i32_i32_i32_i32_i32__(20,)cconstexpr_None__(21,)cconstexpr_None__(33,)cconstexpr_bf16__(34,)cconstexpr_1_d_44269504__(39,)cconstexpr_False__(40,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_MAX, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %in_ptr9, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %q, %K, %V, %Q_LEN, %KV_LEN, %acc_1, %l_i_2, %m_i_3, %off_z, %off_h, %offs_m, %offs_n_4, %kv_start, %kv_offset_5, %stride_kk, %stride_kn, %stride_vn, %stride_vk) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, tensor<128x128xbf16>, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, i32, i32, tensor<128x1xi32>, tensor<1x64xi32>, i32, i32, i32, i32, i32, i32) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) loc(#loc148) + %offset = tt.call 
@"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc434) + %offs_n_6 = tt.splat %offset : i32 -> tensor<1x64xi32> loc(#loc435) + %offs_n_7 = arith.addi %offs_n_4, %offs_n_6 : tensor<1x64xi32> loc(#loc435) + %kv_offset_8 = arith.addi %kv_offset_5, %offset : i32 loc(#loc436) + scf.yield %7#0, %7#1, %7#2, %offs_n_7, %kv_offset_8 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<1x64xi32>, i32 loc(#loc152) + } loc(#loc573) + tt.return %kv_offset_0#0, %kv_offset_0#1, %kv_offset_0#2 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc153) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc154) + %5 = ub.poison : tensor<128xf32> loc(#loc154) + %6 = ub.poison : tensor<128xf32> loc(#loc154) + tt.return %4, %5, %6 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc154) + } loc(#loc145) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_i32_i32_i32_i32_i32__(20,)cconstexpr_None__(21,)cconstexpr_None__(33,)cconstexpr_bf16__(34,)cconstexpr_1_d_44269504__(39,)cconstexpr_False__(40,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc155)), %arg_K: !tt.ptr loc("arg_K"(#loc155)), %arg_V: !tt.ptr loc("arg_V"(#loc155)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc155)), %arg_MAX: !tt.ptr loc("arg_MAX"(#loc155)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc155)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc155)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc155)), %arg_FULL_KV_IDX: !tt.ptr 
loc("arg_FULL_KV_IDX"(#loc155)), %in_ptr9: !tt.ptr loc("in_ptr9"(#loc155)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc155)), %ks0: i32 loc("ks0"(#loc155)), %ks1: i32 loc("ks1"(#loc155)), %ks2: i32 loc("ks2"(#loc155)), %ks3: i32 loc("ks3"(#loc155)), %ks4: i32 loc("ks4"(#loc155)), %ks5: i32 loc("ks5"(#loc155)), %q: tensor<128x128xbf16> loc("q"(#loc155)), %K: !tt.ptr loc("K"(#loc155)), %V: !tt.ptr loc("V"(#loc155)), %Q_LEN: i32 loc("Q_LEN"(#loc155)), %KV_LEN: i32 loc("KV_LEN"(#loc155)), %acc: tensor<128x128xf32> loc("acc"(#loc155)), %l_i: tensor<128xf32> loc("l_i"(#loc155)), %m_i: tensor<128xf32> loc("m_i"(#loc155)), %off_z: i32 loc("off_z"(#loc155)), %off_h: i32 loc("off_h"(#loc155)), %offs_m: tensor<128x1xi32> loc("offs_m"(#loc155)), %offs_n: tensor<1x64xi32> loc("offs_n"(#loc155)), %kv_start: i32 loc("kv_start"(#loc155)), %kv_offset: i32 loc("kv_offset"(#loc155)), %stride_kk: i32 loc("stride_kk"(#loc155)), %stride_kn: i32 loc("stride_kn"(#loc155)), %stride_vn: i32 loc("stride_vn"(#loc155)), %stride_vk: i32 loc("stride_vk"(#loc155))) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) attributes {noinline = false} { + %kv_base_offset = arith.addi %kv_start, %kv_offset : i32 loc(#loc472) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc473) + %offs_n_load = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc474) + %offs_n_load_0 = tt.splat %kv_base_offset : i32 -> tensor<64xi32> loc(#loc475) + %offs_n_load_1 = arith.addi %offs_n_load_0, %offs_n_load : tensor<64xi32> loc(#loc475) + %k = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S64S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%K, %offs_n_load_1, %offs_k, %stride_kn, %stride_kk, %KV_LEN) : (!tt.ptr, tensor<64xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<64x128xbf16> loc(#loc476) + %k_2 = tt.trans %k {order = array} : 
tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc477) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc478) + %qk_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc478) + %qk_4 = tt.dot %q, %k_2, %qk_3, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc478) + %qk_5 = arith.constant 0.0883883461 : f32 loc(#loc479) + %qk_6 = arith.constant 0.0883883461 : f32 loc(#loc479) + %qk_7 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc479) + %qk_8 = arith.mulf %qk_4, %qk_7 : tensor<128x64xf32> loc(#loc479) + %m = tt.call @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S128_1S_i32__(%offs_m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc480) + %n = tt.call @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S1_64S_i32__(%offs_n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc481) + %post_mod_scores = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc482) + %post_mod_scores_9 = arith.cmpi slt, %offs_n, %post_mod_scores : tensor<1x64xi32> loc(#loc482) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc483) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc483) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc483) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc483) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_8, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc483) + %tmp1 = arith.constant false loc(#loc484) + %tmp1_15 = arith.constant dense : tensor<1xi1> loc(#loc484) + %tmp4 = tt.broadcast %m : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc485) + %tmp4_16 = tt.broadcast %n : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc485) + %tmp4_17 = arith.cmpi sge, %tmp4, %tmp4_16 : 
tensor<128x64xi32> loc(#loc485) + %tmp5 = arith.extsi %n : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc486) + %tmp7 = tt.addptr %in_ptr9, %off_z : !tt.ptr, i32 loc(#loc487) + %tmp7_18 = tt.load %tmp7 : !tt.ptr loc(#loc488) + %tmp8 = tt.splat %tmp7_18 : i64 -> tensor<1x64xi64> loc(#loc489) + %tmp8_19 = arith.cmpi slt, %tmp5, %tmp8 : tensor<1x64xi64> loc(#loc489) + %tmp9 = arith.extsi %m : tensor<128x1xi32> to tensor<128x1xi64> loc(#loc490) + %tmp10 = tt.splat %tmp7_18 : i64 -> tensor<128x1xi64> loc(#loc491) + %tmp10_20 = arith.cmpi slt, %tmp9, %tmp10 : tensor<128x1xi64> loc(#loc491) + %tmp11 = tt.broadcast %tmp8_19 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc492) + %tmp11_21 = tt.broadcast %tmp10_20 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc492) + %tmp11_22 = arith.andi %tmp11, %tmp11_21 : tensor<128x64xi1> loc(#loc492) + %tmp12 = arith.andi %tmp4_17, %tmp11_22 : tensor<128x64xi1> loc(#loc493) + %tmp13 = tt.expand_dims %tmp1_15 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc494) + %tmp13_23 = tt.broadcast %tmp13 : tensor<1x1xi1> -> tensor<128x64xi1> loc(#loc494) + %tmp13_24 = arith.ori %tmp13_23, %tmp12 : tensor<128x64xi1> loc(#loc494) + %tmp15 = tt.splat %ks5 : i32 -> tensor<1x64xi32> loc(#loc495) + %tmp15_25 = arith.cmpi sge, %n, %tmp15 : tensor<1x64xi32> loc(#loc495) + %tmp16 = tt.splat %ks5 : i32 -> tensor<1x64xi32> loc(#loc496) + %tmp16_26 = arith.remsi %n, %tmp16 : tensor<1x64xi32> loc(#loc496) + %tmp17 = arith.constant 0 : i32 loc(#loc497) + %tmp17_27 = arith.constant dense<0> : tensor<1xi32> loc(#loc497) + %tmp18 = tt.expand_dims %tmp17_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc498) + %tmp18_28 = tt.broadcast %tmp18 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc498) + %tmp18_29 = arith.cmpi ne, %tmp16_26, %tmp18_28 : tensor<1x64xi32> loc(#loc498) + %tmp19 = arith.constant 0 : i32 loc(#loc499) + %tmp19_30 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc499) + %tmp19_31 = arith.cmpi slt, %tmp16_26, %tmp19_30 : 
tensor<1x64xi32> loc(#loc499) + %tmp20 = arith.constant 0 : i32 loc(#loc500) + %tmp20_32 = arith.cmpi slt, %ks5, %tmp20 : i32 loc(#loc500) + %tmp21 = tt.splat %tmp20_32 : i1 -> tensor<1x64xi1> loc(#loc501) + %tmp21_33 = arith.cmpi ne, %tmp19_31, %tmp21 : tensor<1x64xi1> loc(#loc501) + %tmp22 = arith.andi %tmp18_29, %tmp21_33 : tensor<1x64xi1> loc(#loc502) + %tmp23 = tt.splat %ks5 : i32 -> tensor<1x64xi32> loc(#loc503) + %tmp23_34 = arith.addi %tmp16_26, %tmp23 : tensor<1x64xi32> loc(#loc503) + %tmp24 = arith.select %tmp22, %tmp23_34, %tmp16_26 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc504) + %tmp25 = arith.extsi %tmp24 : tensor<1x64xi32> to tensor<1x64xi64> loc(#loc505) + %tmp26 = tt.splat %tmp7_18 : i64 -> tensor<1x64xi64> loc(#loc506) + %tmp26_35 = arith.cmpi slt, %tmp25, %tmp26 : tensor<1x64xi64> loc(#loc506) + %tmp27 = arith.andi %tmp15_25, %tmp26_35 : tensor<1x64xi1> loc(#loc507) + %tmp28 = tt.broadcast %n : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc508) + %tmp28_36 = tt.broadcast %m : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc508) + %tmp28_37 = arith.subi %tmp28, %tmp28_36 : tensor<128x64xi32> loc(#loc508) + %tmp29 = tt.splat %ks5 : i32 -> tensor<128x64xi32> loc(#loc509) + %tmp29_38 = arith.remsi %tmp28_37, %tmp29 : tensor<128x64xi32> loc(#loc509) + %tmp30 = tt.expand_dims %tmp17_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc510) + %tmp30_39 = tt.broadcast %tmp30 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc510) + %tmp30_40 = arith.cmpi ne, %tmp29_38, %tmp30_39 : tensor<128x64xi32> loc(#loc510) + %tmp31 = arith.constant 0 : i32 loc(#loc511) + %tmp31_41 = arith.constant dense<0> : tensor<128x64xi32> loc(#loc511) + %tmp31_42 = arith.cmpi slt, %tmp29_38, %tmp31_41 : tensor<128x64xi32> loc(#loc511) + %tmp32 = tt.splat %tmp20_32 : i1 -> tensor<128x64xi1> loc(#loc512) + %tmp32_43 = arith.cmpi ne, %tmp31_42, %tmp32 : tensor<128x64xi1> loc(#loc512) + %tmp33 = arith.andi %tmp30_40, %tmp32_43 : tensor<128x64xi1> loc(#loc513) + %tmp34 = 
tt.splat %ks5 : i32 -> tensor<128x64xi32> loc(#loc514) + %tmp34_44 = arith.addi %tmp29_38, %tmp34 : tensor<128x64xi32> loc(#loc514) + %tmp35 = arith.select %tmp33, %tmp34_44, %tmp29_38 : tensor<128x64xi1>, tensor<128x64xi32> loc(#loc515) + %tmp36 = tt.expand_dims %tmp17_27 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc516) + %tmp36_45 = tt.broadcast %tmp36 : tensor<1x1xi32> -> tensor<128x64xi32> loc(#loc516) + %tmp36_46 = arith.cmpi eq, %tmp35, %tmp36_45 : tensor<128x64xi32> loc(#loc516) + %tmp37 = tt.broadcast %tmp27 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc517) + %tmp37_47 = arith.andi %tmp37, %tmp36_46 : tensor<128x64xi1> loc(#loc517) + %tmp38 = arith.ori %tmp13_24, %tmp37_47 : tensor<128x64xi1> loc(#loc518) + %mask_mod_output = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc519) + %mask_mod_output_48 = arith.cmpi slt, %offs_n, %mask_mod_output : tensor<1x64xi32> loc(#loc519) + %mask_mod_output_49 = arith.constant false loc(#loc520) + %mask_mod_output_50 = arith.constant false loc(#loc520) + %mask_mod_output_51 = arith.constant dense : tensor<128x64xi1> loc(#loc520) + %mask_mod_output_52 = tt.broadcast %mask_mod_output_48 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc520) + %mask_mod_output_53 = arith.select %mask_mod_output_52, %tmp38, %mask_mod_output_51 : tensor<128x64xi1>, tensor<128x64xi1> loc(#loc520) + %post_mod_scores_54 = arith.constant 0xFF800000 : f32 loc(#loc521) + %post_mod_scores_55 = arith.constant 0xFF800000 : f32 loc(#loc521) + %post_mod_scores_56 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc521) + %post_mod_scores_57 = arith.select %mask_mod_output_53, %post_mod_scores_14, %post_mod_scores_56 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc521) + %post_mod_scores_58 = arith.constant 1.44269502 : f32 loc(#loc522) + %post_mod_scores_59 = arith.constant 1.44269502 : f32 loc(#loc522) + %post_mod_scores_60 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc522) + %post_mod_scores_61 = 
arith.mulf %post_mod_scores_57, %post_mod_scores_60 : tensor<128x64xf32> loc(#loc522) + %m_ij = tt.call @"triton.language.standard.max__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%post_mod_scores_61) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc523) + %m_ij_62 = arith.maxnumf %m_i, %m_ij : tensor<128xf32> loc(#loc524) + %masked_out_rows = arith.constant 0xFF800000 : f32 loc(#loc525) + %masked_out_rows_63 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc525) + %masked_out_rows_64 = arith.cmpf oeq, %m_ij_62, %masked_out_rows_63 : tensor<128xf32> loc(#loc525) + %m_ij_masked = arith.constant 0 : i32 loc(#loc526) + %m_ij_masked_65 = arith.constant 0.000000e+00 : f32 loc(#loc526) + %m_ij_masked_66 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc526) + %m_ij_masked_67 = arith.select %masked_out_rows_64, %m_ij_masked_66, %m_ij_62 : tensor<128xi1>, tensor<128xf32> loc(#loc526) + %alpha = arith.subf %m_i, %m_ij_masked_67 : tensor<128xf32> loc(#loc527) + %alpha_68 = math.exp2 %alpha : tensor<128xf32> loc(#loc528) + %p = tt.expand_dims %m_ij_masked_67 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc529) + %p_69 = tt.broadcast %p : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc530) + %p_70 = arith.subf %post_mod_scores_61, %p_69 : tensor<128x64xf32> loc(#loc530) + %p_71 = math.exp2 %p_70 : tensor<128x64xf32> loc(#loc531) + %l_i_72 = arith.mulf %l_i, %alpha_68 : tensor<128xf32> loc(#loc532) + %l_i_73 = tt.call @"triton.language.standard.sum__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%p_71) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc533) + %l_i_74 = arith.addf %l_i_72, %l_i_73 : tensor<128xf32> loc(#loc534) + %acc_75 = tt.expand_dims %alpha_68 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc535) + %acc_76 = tt.broadcast %acc_75 : tensor<128x1xf32> -> tensor<128x128xf32> loc(#loc536) + %acc_77 = arith.mulf %acc, %acc_76 : 
tensor<128x128xf32> loc(#loc536) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc537) + %v = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S64S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%V, %offs_n_load_1, %offs_v, %stride_vn, %stride_vk, %KV_LEN) : (!tt.ptr, tensor<64xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<64x128xbf16> loc(#loc538) + %acc_78 = arith.truncf %p_71 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc539) + %acc_79 = arith.constant 0.000000e+00 : f32 loc(#loc540) + %acc_80 = tt.dot %acc_78, %v, %acc_77, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc540) + tt.return %acc_80, %l_i_74, %m_ij_62 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc225) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc226) + %1 = ub.poison : tensor<128xf32> loc(#loc226) + %2 = ub.poison : tensor<128xf32> loc(#loc226) + tt.return %0, %1, %2 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc226) + } loc(#loc155) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S64S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: !tt.ptr loc("ptr"(#loc127)), %offs_m: tensor<64xi32> loc("offs_m"(#loc127)), %offs_n: tensor<128xi32> loc("offs_n"(#loc127)), %stride_m: i32 loc("stride_m"(#loc127)), %stride_n: i32 loc("stride_n"(#loc127)), %M_LEN: i32 loc("M_LEN"(#loc127))) -> tensor<64x128xbf16> attributes {noinline = false} { + %ptr_0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc388) + %ptr_1 = tt.splat %stride_m : i32 -> tensor<64x1xi32> loc(#loc389) + %ptr_2 = arith.muli %ptr_0, %ptr_1 : tensor<64x1xi32> loc(#loc389) + %ptr_3 = tt.splat %ptr : !tt.ptr -> 
tensor<64x1x!tt.ptr> loc(#loc390) + %ptr_4 = tt.addptr %ptr_3, %ptr_2 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc390) + %ptr_5 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc391) + %ptr_6 = tt.splat %stride_n : i32 -> tensor<1x128xi32> loc(#loc392) + %ptr_7 = arith.muli %ptr_5, %ptr_6 : tensor<1x128xi32> loc(#loc392) + %ptr_8 = tt.broadcast %ptr_4 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc393) + %ptr_9 = tt.broadcast %ptr_7 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc393) + %ptr_10 = tt.addptr %ptr_8, %ptr_9 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc393) + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc134) + %1 = tt.splat %M_LEN : i32 -> tensor<64x1xi32> loc(#loc135) + %2 = arith.cmpi slt, %0, %1 : tensor<64x1xi32> loc(#loc135) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc136) + %3 = tt.broadcast %2 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc136) + %cst_11 = arith.constant dense<0.000000e+00> : tensor<64x128xf32> loc(#loc136) + %4 = arith.truncf %cst_11 : tensor<64x128xf32> to tensor<64x128xbf16> loc(#loc136) + %5 = tt.load %ptr_10, %3, %4 : tensor<64x128x!tt.ptr> loc(#loc136) + tt.return %5 : tensor<64x128xbf16> loc(#loc137) + ^bb1: // no predecessors + %6 = ub.poison : tensor<64x128xbf16> loc(#loc138) + tt.return %6 : tensor<64x128xbf16> loc(#loc138) + } loc(#loc127) + tt.func private @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S128_1S_i32__(%indices: tensor<128x1xi32> loc("indices"(#loc227)), %max_len: i32 loc("max_len"(#loc227))) -> tensor<128x1xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<128x1xi32> loc(#loc228) + %1 = arith.remsi %indices, %0 : tensor<128x1xi32> loc(#loc228) + tt.return %1 : tensor<128x1xi32> loc(#loc229) + ^bb1: // no predecessors + %2 = ub.poison : tensor<128x1xi32> loc(#loc230) + tt.return %2 : 
tensor<128x1xi32> loc(#loc230) + } loc(#loc227) + tt.func private @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S1_64S_i32__(%indices: tensor<1x64xi32> loc("indices"(#loc227)), %max_len: i32 loc("max_len"(#loc227))) -> tensor<1x64xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<1x64xi32> loc(#loc228) + %1 = arith.remsi %indices, %0 : tensor<1x64xi32> loc(#loc228) + tt.return %1 : tensor<1x64xi32> loc(#loc229) + ^bb1: // no predecessors + %2 = ub.poison : tensor<1x64xi32> loc(#loc230) + tt.return %2 : tensor<1x64xi32> loc(#loc230) + } loc(#loc227) + tt.func private @"triton.language.standard.max__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%input: tensor<128x64xf32> loc("input"(#loc231))) -> tensor<128xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._elementwise_max__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc232) + tt.reduce.return %2 : f32 loc(#loc232) + }) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc232) + tt.return %0 : tensor<128xf32> loc(#loc234) + ^bb1: // no predecessors + %1 = ub.poison : tensor<128xf32> loc(#loc235) + tt.return %1 : tensor<128xf32> loc(#loc235) + } loc(#loc231) + tt.func private @triton.language.standard._elementwise_max__fp32_fp32__(%a: f32 loc("a"(#loc236)), %b: f32 loc("b"(#loc236))) -> f32 attributes {noinline = false} { + %0 = arith.maxnumf %a, %b : f32 loc(#loc237) + tt.return %0 : f32 loc(#loc238) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc239) + tt.return %1 : f32 loc(#loc239) + } loc(#loc236) + tt.func private @"triton.language.standard.sum__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<128x64xf32> loc("input"(#loc240))) -> tensor<128xf32> attributes {noinline = false} { + %0 = 
"tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc241) + tt.reduce.return %2 : f32 loc(#loc241) + }) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc241) + tt.return %0 : tensor<128xf32> loc(#loc242) + ^bb1: // no predecessors + %1 = ub.poison : tensor<128xf32> loc(#loc243) + tt.return %1 : tensor<128xf32> loc(#loc243) + } loc(#loc240) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc244)), %b: f32 loc("b"(#loc244))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc245) + tt.return %0 : f32 loc(#loc246) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc247) + tt.return %1 : f32 loc(#loc247) + } loc(#loc244) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%loop_iter: i32 loc("loop_iter"(#loc248)), %col_indices: !tt.ptr loc("col_indices"(#loc248)), %total_blocks: i32 loc("total_blocks"(#loc248))) -> i32 attributes {noinline = false} { + %cur_block_idx = arith.constant 2 : i32 loc(#loc552) + %cur_block_idx_0 = arith.constant 2 : i32 loc(#loc552) + %cur_block_idx_1 = arith.divsi %loop_iter, %cur_block_idx_0 : i32 loc(#loc552) + %cur_block = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc553) + %cur_block_2 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc554) + %next_block = arith.constant 1 : i32 loc(#loc555) + %next_block_3 = arith.constant 1 : i32 loc(#loc555) + %next_block_4 = arith.addi %cur_block_idx_1, %next_block_3 : i32 loc(#loc555) + %next_block_5 = arith.cmpi slt, %next_block_4, %total_blocks : i32 loc(#loc556) + %next_block_6 = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc557) + 
%next_block_7 = arith.constant 1 : i32 loc(#loc558) + %next_block_8 = tt.addptr %next_block_6, %next_block_7 : !tt.ptr, i32 loc(#loc558) + %next_block_9 = tt.load %next_block_8, %next_block_5 evictionPolicy = evict_last : !tt.ptr loc(#loc559) + %needs_jump = arith.constant 1 : i32 loc(#loc560) + %needs_jump_10 = arith.constant 1 : i32 loc(#loc560) + %needs_jump_11 = arith.addi %loop_iter, %needs_jump_10 : i32 loc(#loc560) + %needs_jump_12 = arith.constant 2 : i32 loc(#loc561) + %needs_jump_13 = arith.constant 2 : i32 loc(#loc561) + %needs_jump_14 = arith.remsi %needs_jump_11, %needs_jump_13 : i32 loc(#loc561) + %needs_jump_15 = arith.constant 0 : i32 loc(#loc562) + %needs_jump_16 = arith.cmpi eq, %needs_jump_14, %needs_jump_15 : i32 loc(#loc562) + %jump_to_block = arith.subi %next_block_9, %cur_block_2 : i32 loc(#loc563) + %jump_to_block_17 = arith.constant 128 : i32 loc(#loc564) + %jump_to_block_18 = arith.constant 128 : i32 loc(#loc564) + %jump_to_block_19 = arith.muli %jump_to_block, %jump_to_block_18 : i32 loc(#loc564) + %jump_to_block_20 = arith.constant 64 : i32 loc(#loc565) + %jump_to_block_21 = arith.constant 64 : i32 loc(#loc565) + %jump_to_block_22 = arith.subi %jump_to_block_19, %jump_to_block_21 : i32 loc(#loc565) + %offset = arith.extui %needs_jump_16 : i1 to i32 loc(#loc566) + %offset_23 = arith.muli %jump_to_block_22, %offset : i32 loc(#loc566) + %offset_24 = arith.constant 1 : i32 loc(#loc567) + %offset_25 = arith.constant 1 : i32 loc(#loc567) + %offset_26 = arith.extui %needs_jump_16 : i1 to i32 loc(#loc567) + %offset_27 = arith.subi %offset_25, %offset_26 : i32 loc(#loc567) + %offset_28 = arith.constant 64 : i32 loc(#loc568) + %offset_29 = arith.constant 64 : i32 loc(#loc568) + %offset_30 = arith.muli %offset_27, %offset_29 : i32 loc(#loc568) + %offset_31 = arith.addi %offset_23, %offset_30 : i32 loc(#loc569) + tt.return %offset_31 : i32 loc(#loc267) + ^bb1: // no predecessors + %0 = ub.poison : i32 loc(#loc268) + tt.return %0 : i32 loc(#loc268) + 
} loc(#loc248) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_Pi32_i32_i32_i32_i32_i32_i32__(20,)cNone_(21,)cNone_(34,)cconstexpr_0__(36,)cconstexpr_bf16__(41,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc145)), %arg_K: !tt.ptr loc("arg_K"(#loc145)), %arg_V: !tt.ptr loc("arg_V"(#loc145)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc145)), %arg_MAX: !tt.ptr loc("arg_MAX"(#loc145)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc145)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc145)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc145)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc145)), %in_ptr9: !tt.ptr loc("in_ptr9"(#loc145)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc145)), %ks0: i32 loc("ks0"(#loc145)), %ks1: i32 loc("ks1"(#loc145)), %ks2: i32 loc("ks2"(#loc145)), %ks3: i32 loc("ks3"(#loc145)), %ks4: i32 loc("ks4"(#loc145)), %ks5: i32 loc("ks5"(#loc145)), %q: tensor<128x128xbf16> loc("q"(#loc145)), %K: !tt.ptr loc("K"(#loc145)), %V: !tt.ptr loc("V"(#loc145)), %Q_LEN: i32 loc("Q_LEN"(#loc145)), %KV_LEN: i32 loc("KV_LEN"(#loc145)), %acc: tensor<128x128xf32> loc("acc"(#loc145)), %l_i: tensor<128xf32> loc("l_i"(#loc145)), %m_i: tensor<128xf32> loc("m_i"(#loc145)), %off_z: i32 loc("off_z"(#loc145)), %off_h: i32 loc("off_h"(#loc145)), %offs_m: tensor<128x1xi32> loc("offs_m"(#loc145)), %offs_n: tensor<1x64xi32> loc("offs_n"(#loc145)), %kv_start: i32 loc("kv_start"(#loc145)), %kv_indices: !tt.ptr loc("kv_indices"(#loc145)), %kv_num_blocks: i32 loc("kv_num_blocks"(#loc145)), %block_n_end: i32 loc("block_n_end"(#loc145)), %stride_kk: i32 loc("stride_kk"(#loc145)), %stride_kn: i32 loc("stride_kn"(#loc145)), %stride_vn: i32 loc("stride_vn"(#loc145)), %stride_vk: i32 loc("stride_vk"(#loc145))) 
-> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) attributes {noinline = false} { + %kv_offset = arith.constant 0 : i32 loc(#loc432) + %c0_i32 = arith.constant 0 : i32 loc(#loc147) + %c1_i32 = arith.constant 1 : i32 loc(#loc147) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc147) + %1 = arith.bitcast %block_n_end : i32 to i32 loc(#loc147) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc147) + %3 = ub.poison : i32 loc(#loc147) + %kv_offset_0:5 = scf.for %start_n = %0 to %1 step %2 iter_args(%acc_1 = %acc, %l_i_2 = %l_i, %m_i_3 = %m_i, %offs_n_4 = %offs_n, %kv_offset_5 = %kv_offset) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<1x64xi32>, i32) : i32 { + %7:3 = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_i32_i32_i32_i32_i32__(20,)cconstexpr_None__(21,)cconstexpr_None__(33,)cconstexpr_bf16__(34,)cconstexpr_1_d_44269504__(39,)cconstexpr_True__(40,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_MAX, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %in_ptr9, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %q, %K, %V, %Q_LEN, %KV_LEN, %acc_1, %l_i_2, %m_i_3, %off_z, %off_h, %offs_m, %offs_n_4, %kv_start, %kv_offset_5, %stride_kk, %stride_kn, %stride_vn, %stride_vk) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, tensor<128x128xbf16>, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, i32, i32, tensor<128x1xi32>, tensor<1x64xi32>, i32, i32, i32, i32, i32, i32) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) loc(#loc148) + %offset = tt.call 
@"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc434) + %offs_n_6 = tt.splat %offset : i32 -> tensor<1x64xi32> loc(#loc435) + %offs_n_7 = arith.addi %offs_n_4, %offs_n_6 : tensor<1x64xi32> loc(#loc435) + %kv_offset_8 = arith.addi %kv_offset_5, %offset : i32 loc(#loc436) + scf.yield %7#0, %7#1, %7#2, %offs_n_7, %kv_offset_8 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<1x64xi32>, i32 loc(#loc152) + } loc(#loc573) + tt.return %kv_offset_0#0, %kv_offset_0#1, %kv_offset_0#2 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc153) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc154) + %5 = ub.poison : tensor<128xf32> loc(#loc154) + %6 = ub.poison : tensor<128xf32> loc(#loc154) + tt.return %4, %5, %6 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc154) + } loc(#loc145) + tt.func private @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.forward_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pi32_Pi32_Pi32_Pi32_Pi64_Pbf16_i32_i32_i32_i32_i32_i32_bf16S128_128S_Pbf16_Pbf16_i32_i32_fp32S128_128S_fp32S128S_fp32S128S_i32_i32_i32S128_1S_i32S1_64S_i32_i32_i32_i32_i32_i32__(20,)cconstexpr_None__(21,)cconstexpr_None__(33,)cconstexpr_bf16__(34,)cconstexpr_1_d_44269504__(39,)cconstexpr_True__(40,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc155)), %arg_K: !tt.ptr loc("arg_K"(#loc155)), %arg_V: !tt.ptr loc("arg_V"(#loc155)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc155)), %arg_MAX: !tt.ptr loc("arg_MAX"(#loc155)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc155)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc155)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc155)), %arg_FULL_KV_IDX: !tt.ptr 
loc("arg_FULL_KV_IDX"(#loc155)), %in_ptr9: !tt.ptr loc("in_ptr9"(#loc155)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc155)), %ks0: i32 loc("ks0"(#loc155)), %ks1: i32 loc("ks1"(#loc155)), %ks2: i32 loc("ks2"(#loc155)), %ks3: i32 loc("ks3"(#loc155)), %ks4: i32 loc("ks4"(#loc155)), %ks5: i32 loc("ks5"(#loc155)), %q: tensor<128x128xbf16> loc("q"(#loc155)), %K: !tt.ptr loc("K"(#loc155)), %V: !tt.ptr loc("V"(#loc155)), %Q_LEN: i32 loc("Q_LEN"(#loc155)), %KV_LEN: i32 loc("KV_LEN"(#loc155)), %acc: tensor<128x128xf32> loc("acc"(#loc155)), %l_i: tensor<128xf32> loc("l_i"(#loc155)), %m_i: tensor<128xf32> loc("m_i"(#loc155)), %off_z: i32 loc("off_z"(#loc155)), %off_h: i32 loc("off_h"(#loc155)), %offs_m: tensor<128x1xi32> loc("offs_m"(#loc155)), %offs_n: tensor<1x64xi32> loc("offs_n"(#loc155)), %kv_start: i32 loc("kv_start"(#loc155)), %kv_offset: i32 loc("kv_offset"(#loc155)), %stride_kk: i32 loc("stride_kk"(#loc155)), %stride_kn: i32 loc("stride_kn"(#loc155)), %stride_vn: i32 loc("stride_vn"(#loc155)), %stride_vk: i32 loc("stride_vk"(#loc155))) -> (tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32>) attributes {noinline = false} { + %kv_base_offset = arith.addi %kv_start, %kv_offset : i32 loc(#loc472) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc473) + %offs_n_load = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc474) + %offs_n_load_0 = tt.splat %kv_base_offset : i32 -> tensor<64xi32> loc(#loc475) + %offs_n_load_1 = arith.addi %offs_n_load_0, %offs_n_load : tensor<64xi32> loc(#loc475) + %k = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S64S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%K, %offs_n_load_1, %offs_k, %stride_kn, %stride_kk, %KV_LEN) : (!tt.ptr, tensor<64xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<64x128xbf16> loc(#loc476) + %k_2 = tt.trans %k {order = array} : 
tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc477) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc478) + %qk_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc478) + %qk_4 = tt.dot %q, %k_2, %qk_3, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc478) + %qk_5 = arith.constant 0.0883883461 : f32 loc(#loc479) + %qk_6 = arith.constant 0.0883883461 : f32 loc(#loc479) + %qk_7 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc479) + %qk_8 = arith.mulf %qk_4, %qk_7 : tensor<128x64xf32> loc(#loc479) + %m = tt.call @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S128_1S_i32__(%offs_m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc480) + %n = tt.call @torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.get_bounded_indices__i32S1_64S_i32__(%offs_n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc481) + %post_mod_scores = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc482) + %post_mod_scores_9 = arith.cmpi slt, %offs_n, %post_mod_scores : tensor<1x64xi32> loc(#loc482) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc483) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc483) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc483) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc483) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_8, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc483) + %post_mod_scores_15 = arith.constant 1.44269502 : f32 loc(#loc522) + %post_mod_scores_16 = arith.constant 1.44269502 : f32 loc(#loc522) + %post_mod_scores_17 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc522) + %post_mod_scores_18 = arith.mulf %post_mod_scores_14, %post_mod_scores_17 : 
tensor<128x64xf32> loc(#loc522) + %m_ij = tt.call @"triton.language.standard.max__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%post_mod_scores_18) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc523) + %m_ij_19 = arith.maxnumf %m_i, %m_ij : tensor<128xf32> loc(#loc524) + %masked_out_rows = arith.constant 0xFF800000 : f32 loc(#loc525) + %masked_out_rows_20 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc525) + %masked_out_rows_21 = arith.cmpf oeq, %m_ij_19, %masked_out_rows_20 : tensor<128xf32> loc(#loc525) + %m_ij_masked = arith.constant 0 : i32 loc(#loc526) + %m_ij_masked_22 = arith.constant 0.000000e+00 : f32 loc(#loc526) + %m_ij_masked_23 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc526) + %m_ij_masked_24 = arith.select %masked_out_rows_21, %m_ij_masked_23, %m_ij_19 : tensor<128xi1>, tensor<128xf32> loc(#loc526) + %alpha = arith.subf %m_i, %m_ij_masked_24 : tensor<128xf32> loc(#loc527) + %alpha_25 = math.exp2 %alpha : tensor<128xf32> loc(#loc528) + %p = tt.expand_dims %m_ij_masked_24 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc529) + %p_26 = tt.broadcast %p : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc530) + %p_27 = arith.subf %post_mod_scores_18, %p_26 : tensor<128x64xf32> loc(#loc530) + %p_28 = math.exp2 %p_27 : tensor<128x64xf32> loc(#loc531) + %l_i_29 = arith.mulf %l_i, %alpha_25 : tensor<128xf32> loc(#loc532) + %l_i_30 = tt.call @"triton.language.standard.sum__fp32S128_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%p_28) : (tensor<128x64xf32>) -> tensor<128xf32> loc(#loc533) + %l_i_31 = arith.addf %l_i_29, %l_i_30 : tensor<128xf32> loc(#loc534) + %acc_32 = tt.expand_dims %alpha_25 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc535) + %acc_33 = tt.broadcast %acc_32 : tensor<128x1xf32> -> tensor<128x128xf32> loc(#loc536) + %acc_34 = arith.mulf %acc, %acc_33 : tensor<128x128xf32> loc(#loc536) + %offs_v = tt.make_range {end 
= 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc537) + %v = tt.call @"torch._inductor.runtime.compile_tasks.chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.load_checked_2d__Pbf16_i32S64S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%V, %offs_n_load_1, %offs_v, %stride_vn, %stride_vk, %KV_LEN) : (!tt.ptr, tensor<64xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<64x128xbf16> loc(#loc538) + %acc_35 = arith.truncf %p_28 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc539) + %acc_36 = arith.constant 0.000000e+00 : f32 loc(#loc540) + %acc_37 = tt.dot %acc_35, %v, %acc_34, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc540) + tt.return %acc_37, %l_i_31, %m_ij_19 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc225) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc226) + %1 = ub.poison : tensor<128xf32> loc(#loc226) + %2 = ub.poison : tensor<128xf32> loc(#loc226) + tt.return %0, %1, %2 : tensor<128x128xf32>, tensor<128xf32>, tensor<128xf32> loc(#loc226) + } loc(#loc155) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":85:54) +#loc2 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":85:49) +#loc3 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":86:54) +#loc4 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":86:63) +#loc5 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":86:49) +#loc6 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":87:54) +#loc7 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":87:63) +#loc8 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":87:49) +#loc9 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":89:9) +#loc10 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":90:9) +#loc11 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":92:10) +#loc12 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":97:28) +#loc13 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":98:27) +#loc14 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":99:27) +#loc15 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":103:23) +#loc16 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":104:24) +#loc17 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":105:21) +#loc18 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":107:24) +#loc19 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":107:45) +#loc20 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":107:36) +#loc21 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":108:25) +#loc22 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":108:47) +#loc23 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":108:37) +#loc24 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":109:25) +#loc25 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":109:47) +#loc26 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":109:37) +#loc27 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":111:12) +#loc28 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":112:12) +#loc29 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":113:12) +#loc30 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":120:15) +#loc31 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":121:16) +#loc32 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":123:28) +#loc33 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":124:29) +#loc34 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":130:26) +#loc35 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":134:19) +#loc36 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":134:50) +#loc37 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":135:19) +#loc38 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":136:19) +#loc39 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":138:23) +#loc40 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":138:46) +#loc41 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":138:33) +#loc42 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":141:38) +#loc43 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":141:50) +#loc44 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":142:51) +#loc45 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":142:85) +#loc46 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":142:74) +#loc47 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":143:46) +#loc48 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":143:76) +#loc49 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":143:97) +#loc50 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":143:64) +#loc51 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":144:23) +#loc52 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":144:46) +#loc53 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":144:33) +#loc54 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":145:26) +#loc55 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":146:101) +#loc56 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":151:26) +#loc57 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":152:23) +#loc58 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":152:37) +#loc59 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":153:42) +#loc60 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":153:28) +#loc61 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":154:45) +#loc62 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":154:92) +#loc63 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":154:102) +#loc64 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":154:65) +#loc65 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":159:37) +#loc66 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":159:24) +#loc67 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":167:31) +#loc68 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":167:48) +#loc69 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":172:41) +#loc70 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":181:35) +#loc71 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":182:27) +#loc72 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":182:41) +#loc73 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":183:51) +#loc74 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":183:32) +#loc75 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":184:49) +#loc76 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":184:96) +#loc77 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":184:106) +#loc78 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":184:69) +#loc79 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":186:41) +#loc80 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":186:28) +#loc81 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":193:35) +#loc82 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":193:52) +#loc83 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":198:45) +#loc84 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":206:26) +#loc85 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":206:34) +#loc86 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":208:20) +#loc87 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":208:16) +#loc88 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":209:27) +#loc89 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":210:27) +#loc90 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":211:19) +#loc91 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":212:25) +#loc92 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":212:45) +#loc93 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":214:20) +#loc94 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":214:38) +#loc95 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":214:30) +#loc96 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:25) +#loc97 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:21) +#loc98 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:37) +#loc99 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:44) +#loc100 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:33) +#loc101 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:55) +#loc102 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:62) +#loc103 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":217:50) +#loc104 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:53) +#loc105 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:49) +#loc106 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:67) +#loc107 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:62) +#loc108 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:80) +#loc109 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:87) +#loc110 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:75) +#loc111 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:25) +#loc112 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":218:110) +#loc113 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":221:26) +#loc114 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":221:31) +#loc115 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":222:32) +#loc116 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":222:23) +#loc117 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":222:40) +#loc118 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":223:33) +#loc119 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":223:20) +#loc120 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":227:48) +#loc121 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":227:29) +#loc122 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":229:4) +#loc123 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc124 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc125 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc126 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc128 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:27) +#loc129 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:38) +#loc130 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:20) +#loc131 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:56) +#loc132 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:67) +#loc133 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":284:49) +#loc134 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":292:41) +#loc135 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":292:52) +#loc136 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":292:23) +#loc137 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":292:15) +#loc138 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":287:4) +#loc140 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:16) +#loc141 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc142 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc143 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:11) +#loc144 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:4) +#loc146 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":499:16) +#loc147 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":502:40) +#loc148 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":538:16) +#loc149 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":545:63) +#loc150 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":548:26) +#loc151 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":549:21) +#loc152 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":549:8) +#loc153 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":552:11) +#loc154 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":552:4) +#loc156 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":342:32) +#loc157 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":345:26) +#loc158 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":346:48) +#loc159 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":346:35) +#loc160 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":347:107) +#loc161 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":349:17) +#loc162 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":351:19) +#loc163 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":353:14) +#loc164 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":358:36) +#loc165 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":359:36) +#loc166 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":367:44) +#loc167 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":367:69) +#loc168 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":370:35) +#loc169 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":373:23) +#loc170 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":374:23) +#loc171 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":376:33) +#loc172 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":376:23) +#loc173 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":377:22) +#loc174 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":378:23) +#loc175 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":379:23) +#loc176 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":380:23) +#loc177 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":381:23) +#loc178 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":382:23) +#loc179 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":384:24) +#loc180 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":385:24) +#loc181 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":386:32) +#loc182 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":387:25) +#loc183 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":388:92) +#loc184 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":389:92) +#loc185 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":390:25) +#loc186 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":391:24) +#loc187 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":392:24) +#loc188 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":393:39) +#loc189 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":394:25) +#loc190 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":395:24) +#loc191 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":396:24) +#loc192 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":397:23) +#loc193 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":398:25) +#loc194 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":399:25) +#loc195 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":400:92) +#loc196 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":401:25) +#loc197 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":402:24) +#loc198 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":403:24) +#loc199 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":404:39) +#loc200 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":405:25) +#loc201 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":406:24) +#loc202 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":407:24) +#loc203 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":412:48) +#loc204 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":412:73) +#loc205 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":414:69) +#loc206 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":417:27) +#loc207 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":421:51) +#loc208 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":421:27) +#loc209 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":423:35) +#loc210 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":424:51) +#loc211 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":428:31) +#loc212 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":428:25) +#loc213 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":429:51) +#loc214 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":429:39) +#loc215 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":429:21) +#loc216 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":434:16) +#loc217 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":434:34) +#loc218 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":434:24) +#loc219 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":436:22) +#loc220 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":436:16) +#loc221 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":438:26) +#loc222 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":439:107) +#loc223 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":440:22) +#loc224 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":440:44) +#loc225 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":445:11) +#loc226 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":445:4) +#loc228 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":257:21) +#loc229 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":257:11) +#loc230 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":257:4) +#loc232 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":189:40) +#loc234 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":189:15) +#loc235 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":177:4) +#loc237 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":168:27) +#loc238 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":168:11) +#loc239 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":168:4) +#loc241 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc242 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc243 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc245 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc246 = 
loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc247 = loc("/workspace/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc249 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":247:33) +#loc250 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":248:38) +#loc251 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":248:24) +#loc252 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":249:109) +#loc253 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":249:113) +#loc254 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":249:39) +#loc255 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":249:55) +#loc256 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":249:25) +#loc257 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":250:30) +#loc258 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":250:35) +#loc259 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":250:60) +#loc260 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":251:34) +#loc261 = 
loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":251:48) +#loc262 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":251:63) +#loc263 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":252:29) +#loc264 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":252:47) +#loc265 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":252:61) +#loc266 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":252:42) +#loc267 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":253:11) +#loc268 = loc("/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hl/chlgxuwkoqa3b2xlsfagnvc7a4ucsl37zj5momx5db76vetqwkrv.py":253:4) +#loc286 = loc("ZQ"(#loc9)) +#loc287 = loc("HQ"(#loc10)) +#loc288 = loc("ZKV"(#loc11)) +#loc289 = loc("q_start"(#loc12)) +#loc290 = loc("off_zq"(#loc13)) +#loc291 = loc("off_hq"(#loc14)) +#loc292 = loc("off_zkv"(#loc15)) +#loc293 = loc("off_hkv"(#loc16)) +#loc294 = loc("off_g"(#loc17)) +#loc295 = loc("q_offset"(#loc18)) +#loc296 = loc("q_offset"(#loc19)) +#loc297 = loc("q_offset"(#loc20)) +#loc298 = loc("k_offset"(#loc21)) +#loc299 = loc("k_offset"(#loc22)) +#loc300 = loc("k_offset"(#loc23)) +#loc301 = loc("v_offset"(#loc24)) +#loc302 = loc("v_offset"(#loc25)) +#loc303 = loc("v_offset"(#loc26)) +#loc304 = loc("Q"(#loc27)) +#loc305 = loc("K"(#loc28)) +#loc306 = loc("V"(#loc29)) +#loc307 = loc("SPARSE_Z"(#loc30)) +#loc308 = loc("SPARSE_HQ"(#loc31)) +#loc309 = loc("sparse_idx_z"(#loc32)) +#loc310 = loc("sparse_idx_hq"(#loc33)) +#loc311 = loc("stride_kv_idx_h"(#loc34)) +#loc312 = 
loc("m_i"(#loc35)) +#loc313 = loc("m_i"(#loc36)) +#loc314 = loc("l_i"(#loc37)) +#loc315 = loc("acc"(#loc38)) +#loc316 = loc("offs_m"(#loc39)) +#loc317 = loc("offs_m"(#loc40)) +#loc318 = loc("offs_m"(#loc41)) +#loc319 = loc("sparse_hz_offset"(#loc42)) +#loc320 = loc("sparse_hz_offset"(#loc43)) +#loc321 = loc("sparse_kv_num_blks_offset"(#loc44)) +#loc322 = loc("sparse_kv_num_blks_offset"(#loc45)) +#loc323 = loc("sparse_kv_num_blks_offset"(#loc46)) +#loc324 = loc("sparse_kv_idx_offset"(#loc47)) +#loc325 = loc("sparse_kv_idx_offset"(#loc48)) +#loc326 = loc("sparse_kv_idx_offset"(#loc49)) +#loc327 = loc("sparse_kv_idx_offset"(#loc50)) +#loc328 = loc("offs_m"(#loc51)) +#loc329 = loc("offs_m"(#loc52)) +#loc330 = loc("offs_m"(#loc53)) +#loc331 = loc("offs_k"(#loc54)) +#loc332 = loc("q"(#loc55)) +#loc333 = loc("kv_indices"(#loc56)) +#loc334 = loc("kv_start"(#loc57)) +#loc335 = loc("kv_start"(#loc58)) +#loc336 = loc("kv_num_blocks"(#loc59)) +#loc337 = loc("kv_num_blocks"(#loc60)) +#loc338 = loc("block_n_end"(#loc61)) +#loc339 = loc("block_n_end"(#loc62)) +#loc340 = loc("block_n_end"(#loc63)) +#loc341 = loc("block_n_end"(#loc64)) +#loc342 = loc("offs_n"(#loc65)) +#loc343 = loc("offs_n"(#loc66)) +#loc344 = loc("kv_indices"(#loc70)) +#loc345 = loc("kv_start"(#loc71)) +#loc346 = loc("kv_start"(#loc72)) +#loc347 = loc("kv_num_blocks"(#loc73)) +#loc348 = loc("kv_num_blocks"(#loc74)) +#loc349 = loc("block_n_end"(#loc75)) +#loc350 = loc("block_n_end"(#loc76)) +#loc351 = loc("block_n_end"(#loc77)) +#loc352 = loc("block_n_end"(#loc78)) +#loc353 = loc("offs_n"(#loc79)) +#loc354 = loc("offs_n"(#loc80)) +#loc355 = loc("l_i"(#loc84)) +#loc356 = loc("l_i"(#loc85)) +#loc357 = loc("acc"(#loc86)) +#loc358 = loc("acc"(#loc87)) +#loc359 = loc("idx_zq"(#loc88)) +#loc360 = loc("idx_hq"(#loc89)) +#loc361 = loc("idx_m"(#loc90)) +#loc362 = loc("idx_d"(#loc91)) +#loc363 = loc("idx_d"(#loc92)) +#loc364 = loc("mask"(#loc93)) +#loc365 = loc("mask"(#loc94)) +#loc366 = loc("mask"(#loc95)) +#loc367 = 
loc("xindex"(#loc96)) +#loc368 = loc("xindex"(#loc97)) +#loc369 = loc("xindex"(#loc98)) +#loc370 = loc("xindex"(#loc99)) +#loc371 = loc("xindex"(#loc100)) +#loc372 = loc("xindex"(#loc101)) +#loc373 = loc("xindex"(#loc102)) +#loc374 = loc("xindex"(#loc103)) +#loc375 = loc("off_hz"(#loc113)) +#loc376 = loc("off_hz"(#loc114)) +#loc377 = loc("l_ptrs"(#loc115)) +#loc378 = loc("l_ptrs"(#loc116)) +#loc379 = loc("l_ptrs"(#loc117)) +#loc380 = loc("lse"(#loc118)) +#loc381 = loc("lse"(#loc119)) +#loc388 = loc("ptr"(#loc128)) +#loc389 = loc("ptr"(#loc129)) +#loc390 = loc("ptr"(#loc130)) +#loc391 = loc("ptr"(#loc131)) +#loc392 = loc("ptr"(#loc132)) +#loc393 = loc("ptr"(#loc133)) +#loc432 = loc("kv_offset"(#loc146)) +#loc433 = loc("acc"(#loc147)) +#loc434 = loc("offset"(#loc149)) +#loc435 = loc("offs_n"(#loc150)) +#loc436 = loc("kv_offset"(#loc151)) +#loc472 = loc("kv_base_offset"(#loc156)) +#loc473 = loc("offs_k"(#loc157)) +#loc474 = loc("offs_n_load"(#loc158)) +#loc475 = loc("offs_n_load"(#loc159)) +#loc476 = loc("k"(#loc160)) +#loc477 = loc("k"(#loc161)) +#loc478 = loc("qk"(#loc162)) +#loc479 = loc("qk"(#loc163)) +#loc480 = loc("m"(#loc164)) +#loc481 = loc("n"(#loc165)) +#loc482 = loc("post_mod_scores"(#loc166)) +#loc483 = loc("post_mod_scores"(#loc167)) +#loc484 = loc("tmp1"(#loc168)) +#loc485 = loc("tmp4"(#loc169)) +#loc486 = loc("tmp5"(#loc170)) +#loc487 = loc("tmp7"(#loc171)) +#loc488 = loc("tmp7"(#loc172)) +#loc489 = loc("tmp8"(#loc173)) +#loc490 = loc("tmp9"(#loc174)) +#loc491 = loc("tmp10"(#loc175)) +#loc492 = loc("tmp11"(#loc176)) +#loc493 = loc("tmp12"(#loc177)) +#loc494 = loc("tmp13"(#loc178)) +#loc495 = loc("tmp15"(#loc179)) +#loc496 = loc("tmp16"(#loc180)) +#loc497 = loc("tmp17"(#loc181)) +#loc498 = loc("tmp18"(#loc182)) +#loc499 = loc("tmp19"(#loc183)) +#loc500 = loc("tmp20"(#loc184)) +#loc501 = loc("tmp21"(#loc185)) +#loc502 = loc("tmp22"(#loc186)) +#loc503 = loc("tmp23"(#loc187)) +#loc504 = loc("tmp24"(#loc188)) +#loc505 = loc("tmp25"(#loc189)) +#loc506 = 
loc("tmp26"(#loc190)) +#loc507 = loc("tmp27"(#loc191)) +#loc508 = loc("tmp28"(#loc192)) +#loc509 = loc("tmp29"(#loc193)) +#loc510 = loc("tmp30"(#loc194)) +#loc511 = loc("tmp31"(#loc195)) +#loc512 = loc("tmp32"(#loc196)) +#loc513 = loc("tmp33"(#loc197)) +#loc514 = loc("tmp34"(#loc198)) +#loc515 = loc("tmp35"(#loc199)) +#loc516 = loc("tmp36"(#loc200)) +#loc517 = loc("tmp37"(#loc201)) +#loc518 = loc("tmp38"(#loc202)) +#loc519 = loc("mask_mod_output"(#loc203)) +#loc520 = loc("mask_mod_output"(#loc204)) +#loc521 = loc("post_mod_scores"(#loc205)) +#loc522 = loc("post_mod_scores"(#loc206)) +#loc523 = loc("m_ij"(#loc207)) +#loc524 = loc("m_ij"(#loc208)) +#loc525 = loc("masked_out_rows"(#loc209)) +#loc526 = loc("m_ij_masked"(#loc210)) +#loc527 = loc("alpha"(#loc211)) +#loc528 = loc("alpha"(#loc212)) +#loc529 = loc("p"(#loc213)) +#loc530 = loc("p"(#loc214)) +#loc531 = loc("p"(#loc215)) +#loc532 = loc("l_i"(#loc216)) +#loc533 = loc("l_i"(#loc217)) +#loc534 = loc("l_i"(#loc218)) +#loc535 = loc("acc"(#loc219)) +#loc536 = loc("acc"(#loc220)) +#loc537 = loc("offs_v"(#loc221)) +#loc538 = loc("v"(#loc222)) +#loc539 = loc("acc"(#loc223)) +#loc540 = loc("acc"(#loc224)) +#loc552 = loc("cur_block_idx"(#loc249)) +#loc553 = loc("cur_block"(#loc250)) +#loc554 = loc("cur_block"(#loc251)) +#loc555 = loc("next_block"(#loc252)) +#loc556 = loc("next_block"(#loc253)) +#loc557 = loc("next_block"(#loc254)) +#loc558 = loc("next_block"(#loc255)) +#loc559 = loc("next_block"(#loc256)) +#loc560 = loc("needs_jump"(#loc257)) +#loc561 = loc("needs_jump"(#loc258)) +#loc562 = loc("needs_jump"(#loc259)) +#loc563 = loc("jump_to_block"(#loc260)) +#loc564 = loc("jump_to_block"(#loc261)) +#loc565 = loc("jump_to_block"(#loc262)) +#loc566 = loc("offset"(#loc263)) +#loc567 = loc("offset"(#loc264)) +#loc568 = loc("offset"(#loc265)) +#loc569 = loc("offset"(#loc266)) +#loc570 = loc("l_i"(#loc433)) +#loc571 = loc("m_i"(#loc570)) +#loc572 = loc("offs_n"(#loc571)) +#loc573 = loc("kv_offset"(#loc572)) diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/__grp__triton_poi_fused_new_zeros_1.json b/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/__grp__triton_poi_fused_new_zeros_1.json new file mode 100644 index 0000000000000000000000000000000000000000..95c277ddce978d5883a19e0e6394ded1c0b26c85 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/__grp__triton_poi_fused_new_zeros_1.json @@ -0,0 +1 @@ +{"child_paths": {"triton_poi_fused_new_zeros_1.source": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.source", "triton_poi_fused_new_zeros_1.ttir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.ttir", "triton_poi_fused_new_zeros_1.ttgir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.ttgir", "triton_poi_fused_new_zeros_1.llir": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.llir", "triton_poi_fused_new_zeros_1.ptx": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.ptx", "triton_poi_fused_new_zeros_1.cubin": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.cubin", "triton_poi_fused_new_zeros_1.json": "/workspace/hanrui/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.json"}} \ No newline at end of file diff --git 
a/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.json b/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.json new file mode 100644 index 0000000000000000000000000000000000000000..1763cdc7954c2932273615e3329fb5dda994f1a4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/triton/0/X4SFCUNHNVK6FR6CSIUU4JIDJXVPMITMWOHHGKRF3QCUQNY7M77Q/triton_poi_fused_new_zeros_1.json @@ -0,0 +1 @@ +{"hash": "bf245151a76d55e2c7c292294e25034deaf6226cb38e732a25dc0548371f67ff", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 4, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 0, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_poi_fused_new_zeros_1"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ts/ctsovtz6kkoo5eiywzgs22dkumqa4mozeu7kd76grusyr4wk3sow.py b/SpecForge-ext/cache/compiled_kernels/ts/ctsovtz6kkoo5eiywzgs22dkumqa4mozeu7kd76grusyr4wk3sow.py new file mode 100644 index 
0000000000000000000000000000000000000000..ff388c538c74dabd2581ee812395b94e1aa7512e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ts/ctsovtz6kkoo5eiywzgs22dkumqa4mozeu7kd76grusyr4wk3sow.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 
'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr 
= 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. 
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * 
off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
# One inner-loop step of the dQ backward pass for a single (query tile, key/value tile) pair.
#
# Recomputes the scaled score tile q @ kT, applies the inlined mask for
# partially masked blocks, rebuilds the probabilities from `lse`, forms
# dS = P * (dP - Di) and accumulates `dq += dS @ K`.
#
# Parameters used by this body:
#   dq                 - running dQ accumulator; updated and returned.
#   q, do              - already-loaded query and dOutput tiles.
#   kT_ptrs, vT_ptrs   - pointer tiles for the transposed K / V block; they are
#                        passed to load_checked_2d with strides of None, i.e.
#                        the offsets are already folded into the pointers.
#   Di, lse            - per-query-row values; Di is broadcast with [:, None],
#                        lse is subtracted as passed in.
#   Q_LEN, KV_LEN      - sequence bounds used for index wrapping and masking.
#   off_z              - batch index; used to index in_ptr16 in the mask.
#   offs_m2, offs_n2   - query-row / key-column indices of this tile.
#   offs_k, offs_v     - head-dim indices for the K and V loads.
#   MATMUL_PRECISION   - dtype that dS is cast to before the final dot.
#   RCP_LN2            - factor applied to the scores before exp2.
#   IS_FULL_BLOCKS     - when true the mask computation is skipped entirely.
# The leading arg_* / out_ptr0 parameters, off_hq, the stride_* values,
# kv_indices and sparse_kv_num_blocks are accepted but not read in this body.
@triton.jit
def bwd_dq_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time specialisation constants baked in by the code generator.
    # Several of them (e.g. ROWS_GUARANTEED_SAFE, OUTPUT_LOGSUMEXP, BLOCK_*)
    # are not referenced in this body.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: the divisibility flags and lengths are passed in reversed
    # (head-dim first) order because K is loaded transposed.
    kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # pre_mod_scores is assigned but not read again in this body.
    pre_mod_scores = qk
    # n: key indices laid out along the column axis ([None, :]).
    n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
    # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
    # m: query indices laid out along the row axis ([:, None]).
    m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)

    # The inlined score modification is the identity here.
    tmp0 = (qk)
    post_mod_scores = tmp0




    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask computation. With L = the value loaded from
        # in_ptr16[off_z], the result is
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0)
        # NOTE(review): L is presumably a per-batch valid sequence length -
        # confirm against the mask function this kernel was generated from.
        tmp1 = tl.full([1], False, tl.int1)
        tmp2 = (m)
        tmp3 = (n)
        tmp4 = tmp2 >= tmp3
        tmp5 = tmp3.to(tl.int64)
        tmp6 = (off_z)
        tmp7 = tl.load(in_ptr16 + tmp6)
        tmp8 = tmp5 < tmp7
        tmp9 = tmp2.to(tl.int64)
        tmp10 = tmp9 < tmp7
        tmp11 = tmp8 & tmp10
        tmp12 = tmp4 & tmp11
        tmp13 = tmp1 | tmp12
        tmp14 = tl.full([1], 2048, tl.int32)
        tmp15 = tmp3 >= tmp14
        # tmp16..tmp24: n mod 2048 with a sign fix-up - when the remainder is
        # non-zero and its sign differs from the divisor's, the divisor is
        # added, so the result takes the divisor's sign.
        tmp16 = (tmp3 % tmp14)
        tmp17 = tl.full([1], 0, tl.int32)
        tmp18 = tmp16 != tmp17
        tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
        tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
        tmp21 = tmp19 != tmp20
        tmp22 = tmp18 & tmp21
        tmp23 = tmp16 + tmp14
        tmp24 = tl.where(tmp22, tmp23, tmp16)
        tmp25 = tmp24.to(tl.int64)
        tmp26 = tmp25 < tmp7
        tmp27 = tmp15 & tmp26
        # tmp28..tmp35: (n - m) mod 2048 with the same sign fix-up.
        tmp28 = tmp3 - tmp2
        tmp29 = (tmp28 % tmp14)
        tmp30 = tmp29 != tmp17
        tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
        tmp32 = tmp31 != tmp20
        tmp33 = tmp30 & tmp32
        tmp34 = tmp29 + tmp14
        tmp35 = tl.where(tmp33, tmp34, tmp29)
        tmp36 = tmp35 == tmp17
        tmp37 = tmp27 & tmp36
        tmp38 = tmp13 | tmp37
        mask_mod_output = tmp38


        # Apply the mask for a partially masked block: masked-out scores
        # become -inf so their probability below is exactly zero.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Scale by RCP_LN2 so that exp2 can be used below.
        # NOTE(review): assumes `lse` is already expressed in the matching
        # base-2 domain - confirm against the forward kernel.
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    # NB: reversed argument order since V is loaded transposed.
    vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)

    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The inlined gradient modification is the identity here.
    tmp39 = (ds)
    grad_scores = tmp39


    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    if WRITE_DQ:
        # scatter_mask is computed but not read again in this body.
        scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = grad_scores

    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ. kT is transposed back so the product is dS @ K.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq
# One inner-loop step of the dK/dV backward pass for a single (key/value tile, query tile) pair.
#
# Works in the transposed layout (keys along rows, queries along columns):
# recomputes the score tile k @ qT, applies the inlined mask for partially
# masked blocks, rebuilds pT from the logsumexp, then accumulates
#   dv += pT @ do      and      dk += dsT @ q,
# where dsT = pT * (dpT - Di).
#
# Parameters used by this body:
#   dk, dv             - running accumulators; updated and returned as a pair.
#   qT_ptrs, do_ptrs   - pointer tiles for the transposed Q block and the
#                        dOutput block; both are passed to load_checked_2d
#                        with strides of None.
#   k, v               - already-loaded key and value tiles.
#   DELTA, LSE         - base pointers indexed by offs_m1.
#   Q_LEN, KV_LEN      - sequence bounds used for index wrapping and masking.
#   off_z              - batch index; used to index in_ptr16 in the mask.
#   offs_n1, offs_m1   - key-row / query-column indices of this tile.
#   offs_k, offs_v     - head-dim indices for the Q and dO loads.
#   MATMUL_PRECISION   - dtype the probability / dS tiles are cast to before
#                        each dot.
#   RCP_LN2            - factor applied to the scores before exp2.
#   IS_FULL_BLOCKS     - when true the mask computation is skipped entirely.
# The leading arg_* / out_ptr0 parameters, the stride_* values, q_indices and
# sparse_q_num_blocks are accepted but not read in this body.
@triton.jit
def bwd_dkdv_block_mn(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS,
):
    # Compile-time specialisation constants baked in by the code generator.
    # Several of them are not referenced in this body.
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    OUTPUT_MAX : tl.constexpr = False
    FLOAT32_PRECISION : tl.constexpr = 'tf32'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    GQA_SHARED_HEADS : tl.constexpr = 4
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    INDEX_DTYPE : tl.constexpr = tl.int32


    # NB: reversed argument order since Q is loaded transposed.
    qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        lse = tl.load(LSE + offs_m1)
    else:
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    # Replace -inf logsumexp entries with 0.0 so the subtraction below does
    # not evaluate (-inf) - (-inf).
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    # m: query indices laid out along the column axis ([None, :]).
    m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
    # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
    # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
    # n: key indices laid out along the row axis ([:, None]).
    n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)

    # pre_mod_scores is assigned but not read again in this body; the inlined
    # score modification is the identity here.
    pre_mod_scores = qkT
    tmp40 = (qkT)
    post_mod_scores = tmp40



    if not IS_DIVISIBLE:
        post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Inlined mask computation. With L = the value loaded from
        # in_ptr16[off_z], the result is
        #   (m >= n and n < L and m < L)
        #   or (n >= 2048 and (n mod 2048) < L and ((n - m) mod 2048) == 0)
        # NOTE(review): L is presumably a per-batch valid sequence length -
        # confirm against the mask function this kernel was generated from.
        tmp41 = tl.full([1], False, tl.int1)
        tmp42 = (m)
        tmp43 = (n)
        tmp44 = tmp42 >= tmp43
        tmp45 = tmp43.to(tl.int64)
        tmp46 = (off_z)
        tmp47 = tl.load(in_ptr16 + tmp46)
        tmp48 = tmp45 < tmp47
        tmp49 = tmp42.to(tl.int64)
        tmp50 = tmp49 < tmp47
        tmp51 = tmp48 & tmp50
        tmp52 = tmp44 & tmp51
        tmp53 = tmp41 | tmp52
        tmp54 = tl.full([1], 2048, tl.int32)
        tmp55 = tmp43 >= tmp54
        # tmp56..tmp64: n mod 2048 with a sign fix-up - when the remainder is
        # non-zero and its sign differs from the divisor's, the divisor is
        # added, so the result takes the divisor's sign.
        tmp56 = (tmp43 % tmp54)
        tmp57 = tl.full([1], 0, tl.int32)
        tmp58 = tmp56 != tmp57
        tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
        tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
        tmp61 = tmp59 != tmp60
        tmp62 = tmp58 & tmp61
        tmp63 = tmp56 + tmp54
        tmp64 = tl.where(tmp62, tmp63, tmp56)
        tmp65 = tmp64.to(tl.int64)
        tmp66 = tmp65 < tmp47
        tmp67 = tmp55 & tmp66
        # tmp68..tmp75: (n - m) mod 2048 with the same sign fix-up.
        tmp68 = tmp43 - tmp42
        tmp69 = (tmp68 % tmp54)
        tmp70 = tmp69 != tmp57
        tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
        tmp72 = tmp71 != tmp60
        tmp73 = tmp70 & tmp72
        tmp74 = tmp69 + tmp54
        tmp75 = tl.where(tmp73, tmp74, tmp69)
        tmp76 = tmp75 == tmp57
        tmp77 = tmp67 & tmp76
        tmp78 = tmp53 | tmp77
        mask_mod_output = tmp78

        # Apply the mask to the scores. This runs only when IS_FULL_BLOCKS is
        # false; masked-out scores become -inf so their probability below is
        # exactly zero.
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        # Scale by RCP_LN2 so that exp2 can be used below.
        # NOTE(review): assumes the stored LSE is in the matching base-2
        # domain - confirm against the forward kernel.
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    # The inlined gradient modification is the identity here.
    tmp79 = (dsT)
    grad_scores = tmp79



    if not IS_DIVISIBLE:
        grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    # WRITE_DQ is True in this specialisation, so this branch is not taken;
    # none of the names it assigns are read later in this body.
    if not WRITE_DQ:
        idx_b = off_z
        idx_h = off_hq
        idx_m = m
        idx_n = n
        scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute dK. qT is transposed back so the product is dsT @ q.
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
# Generated two-pass row reduction. Each program handles XBLOCK rows and walks
# the reduction axis in chunks of R0_BLOCK columns.
#
# With a = in_ptr0[x, r], w = in_ptr1[r], b = in_ptr2[x, r], s = in_ptr3[x]
# (all converted to float32, row stride ks0):
#   pass 1:  acc[x]    = sum over r of (a * w) * b
#   pass 2:  out[x, r] = (a * w) * s + (acc[x] * -0.5 * s**3 / ks0) * (b * 2.0)
#
# Per the signature in triton_meta, in_ptr0..in_ptr2 and out_ptr1 are bf16
# pointers, in_ptr3 is an fp32 pointer and ks0 is an i64 scalar.
# NOTE(review): the arithmetic looks like the input gradient of an
# RMSNorm-style normalisation (w = weight, s = reciprocal RMS) - confirm
# against the graph this kernel was generated from.
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 4096},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_expand_mul_pow_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # rnumel, RBLOCK, rbase (and roffset / rindex inside the loops) are
    # aliases that are assigned but not read in this body.
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    # Row indices handled by this program, as a column vector, plus the mask
    # for rows past xnumel.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Column offsets within one reduction chunk, as a row vector.
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Pass 1: accumulate (a * w) * b per row in float32.
    _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # These loads use 'evict_last'; the same addresses are read again in
        # the second pass.
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 * tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
        tmp9 = _tmp8 + tmp7
        # Only in-bounds lanes update the accumulator.
        _tmp8 = tl.where(r0_mask & xmask, tmp9, _tmp8)
    # Collapse the per-lane partial sums into one value per row.
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    # Per-row scalar s, loaded once and reused for every column chunk.
    tmp14 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    # Pass 2: recompute a * w, combine with the row sum and store the result.
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # The per-row inputs use 'evict_first' here; in_ptr1 keeps
        # 'evict_last'.
        tmp10 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp24 = tl.load(in_ptr2 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tmp10 * tmp11
        tmp13 = tmp12.to(tl.float32)
        # First term: (a * w) * s.
        tmp15 = tmp13 * tmp14
        # Second term: (acc * -0.5 * s**3 / ks0) * (b * 2.0).
        tmp16 = -0.5
        tmp17 = tmp8 * tmp16
        tmp18 = tmp14 * tmp14
        tmp19 = tmp18 * tmp14
        tmp20 = tmp17 * tmp19
        # ks0 serves both as the row stride in the addressing above and as
        # the divisor here.
        tmp21 = ks0
        tmp22 = tmp21.to(tl.float32)
        tmp23 = (tmp20 / tmp22)
        tmp25 = tmp24.to(tl.float32)
        tmp26 = 2.0
        tmp27 = tmp25 * tmp26
        tmp28 = tmp23 * tmp27
        tmp29 = tmp15 + tmp28
        tmp30 = tmp29.to(tl.float32)
        tl.store(out_ptr1 + (r0_1 + ks0*x0), tmp30, r0_mask & xmask)
b/SpecForge-ext/cache/compiled_kernels/uc/cucsqr3b7upw3tq57br5dklnxe2ga3abu5v244ypr5tyktygrfzf.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bl/cbl5w3gbsoiwosnwknfarb55lklpfntpuz6q4jjui35yi2wdwepo.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# 
%tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): 
[['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + 
tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bt/cbtjqwfnxxsdpc3bcwbd3t3jzev3hcwuokttvgumrw2z2aszsthx.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:7" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 
1]cuda:7" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:7, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, 
regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, 
arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the 
value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z 
* SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, 
+): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, 
tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + 
GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = 
tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + 
assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream7) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream7 = get_raw_stream(7) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del 
primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:7', 
dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:7', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/us/973c10a7a62b75138eb39aef7d46f0e791d5734316944ad60d1884c4a455acd7.best_config b/SpecForge-ext/cache/compiled_kernels/us/973c10a7a62b75138eb39aef7d46f0e791d5734316944ad60d1884c4a455acd7.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/us/973c10a7a62b75138eb39aef7d46f0e791d5734316944ad60d1884c4a455acd7.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/us/cuss5perekriv5zlubnk52f3pej5qclzrmta7cvidhkpzhupmvt5.py b/SpecForge-ext/cache/compiled_kernels/us/cuss5perekriv5zlubnk52f3pej5qclzrmta7cvidhkpzhupmvt5.py new file mode 100644 index 0000000000000000000000000000000000000000..153e42fcce0656cc3fe2e2fe2b8bd408e87dd4dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/us/cuss5perekriv5zlubnk52f3pej5qclzrmta7cvidhkpzhupmvt5.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics 
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/uv/cuv3vrtusq2q2nsfbjisxnso2yv7wfwd5g5wzneucl6wsg7qdu22.py b/SpecForge-ext/cache/compiled_kernels/uv/cuv3vrtusq2q2nsfbjisxnso2yv7wfwd5g5wzneucl6wsg7qdu22.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a3675ba370136adf57c5c62215e23aea8d7f53 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/uv/cuv3vrtusq2q2nsfbjisxnso2yv7wfwd5g5wzneucl6wsg7qdu22.py @@ -0,0 +1,45 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + 
xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/uv/cuvrovm6bqb5asyuhufqtqvlrd4crf64b76sboaqgrrpgxqj2krm.py b/SpecForge-ext/cache/compiled_kernels/uv/cuvrovm6bqb5asyuhufqtqvlrd4crf64b76sboaqgrrpgxqj2krm.py new file mode 100644 index 0000000000000000000000000000000000000000..b1240a97b65003a00b4a10231e95c74a229ac8f9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uv/cuvrovm6bqb5asyuhufqtqvlrd4crf64b76sboaqgrrpgxqj2krm.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 
'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = 
(r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/uy/cuy7ehjdb4zqfsjtt3uzwfmkjssfasiwxrtpfvjxh6awb5o2hesl.py b/SpecForge-ext/cache/compiled_kernels/uy/cuy7ehjdb4zqfsjtt3uzwfmkjssfasiwxrtpfvjxh6awb5o2hesl.py new file mode 100644 index 0000000000000000000000000000000000000000..d03cd0e20758caabb0fc80764fa5c1eca1d77d7e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/uy/cuy7ehjdb4zqfsjtt3uzwfmkjssfasiwxrtpfvjxh6awb5o2hesl.py @@ -0,0 +1,309 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile 
import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/m3/cm37uosceekc3rzf227uaqezxryaoqfpjbxdhfyiygonml26lx5x.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = 
PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor 
"bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + 
(x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_14 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:5" = PlaceHolder[target=primals_14] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:5" = PlaceHolder[target=primals_6] +# %squeeze : 
Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_14, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24, s24*s25, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24, s24*s25, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_14, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s25, s9, s24 - 
((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s25*Max(1, s24 - ((s24//2))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24, s24*s25, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + 
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = 
tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_14, (s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, 
aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream5) + del primals_12 + ps1 = s24*s25 + buf1 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24, s24*s25, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s25*s48*s9 + stream5 = get_raw_stream(5) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_14, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream5) + del primals_14 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s25, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:5', dtype=torch.int64) + primals_9 = 1 + primals_10 = 2 + 
primals_11 = 32 + primals_12 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_13 = 8 + primals_14 = rand_strided((2, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/vq/cvqvxwsz5trm7yg2d2gcqm3fnjjobjar5tizng43rigkxges3nhj.py b/SpecForge-ext/cache/compiled_kernels/vq/cvqvxwsz5trm7yg2d2gcqm3fnjjobjar5tizng43rigkxges3nhj.py new file mode 100644 index 0000000000000000000000000000000000000000..845ea94c013de0de970609d70de9050cf9d8fb52 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vq/cvqvxwsz5trm7yg2d2gcqm3fnjjobjar5tizng43rigkxges3nhj.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 
2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/vq/da7d5a6e5bf7aaf86b87ee9c4e2b07bd8e90f2a5f60594acc706c8c25d864b17.best_config b/SpecForge-ext/cache/compiled_kernels/vq/da7d5a6e5bf7aaf86b87ee9c4e2b07bd8e90f2a5f60594acc706c8c25d864b17.best_config new file mode 100644 index 0000000000000000000000000000000000000000..37707241555f35a01f7e4a693e0cda27ae37aab0 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/vq/da7d5a6e5bf7aaf86b87ee9c4e2b07bd8e90f2a5f60594acc706c8c25d864b17.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 22, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/vy/cvyoqg7jzeadarrgggxhs2djmxugxyvjhim72vfstqc4io4qsecj.py b/SpecForge-ext/cache/compiled_kernels/vy/cvyoqg7jzeadarrgggxhs2djmxugxyvjhim72vfstqc4io4qsecj.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbe10f36dfcce8b26c0f68442ee9cedd1ebb751 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vy/cvyoqg7jzeadarrgggxhs2djmxugxyvjhim72vfstqc4io4qsecj.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, 
eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/vy/e1eed843b71c28d52c2313e97c0fc3d0959b52b15c6df3e7663dd70b1ce408b3.best_config b/SpecForge-ext/cache/compiled_kernels/vy/e1eed843b71c28d52c2313e97c0fc3d0959b52b15c6df3e7663dd70b1ce408b3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5abd9f60ea88e5a77867b391996e5ebb14da0a20 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vy/e1eed843b71c28d52c2313e97c0fc3d0959b52b15c6df3e7663dd70b1ce408b3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 51, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wb/cwbu7golosgf3flvkirgesrzzw3iq2u2nierku62nsic53n5awn6.py b/SpecForge-ext/cache/compiled_kernels/wb/cwbu7golosgf3flvkirgesrzzw3iq2u2nierku62nsic53n5awn6.py new file mode 100644 index 0000000000000000000000000000000000000000..ef5d9d1305ff3eca23d9e7098e12797b8578ff83 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wb/cwbu7golosgf3flvkirgesrzzw3iq2u2nierku62nsic53n5awn6.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import 
maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kk/ckkonpj4ig4m6kul577movgkkpytb6t5h6kpoun5efcbvgaje63a.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = 
PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[8][1]cuda:3" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 
16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : 
tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. 
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, 
primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream3) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 
262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:3', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/we/6f4fcddd234315db405c51971ff50a00640ddd0931ecee15eec6ac158d802eb8.best_config b/SpecForge-ext/cache/compiled_kernels/we/6f4fcddd234315db405c51971ff50a00640ddd0931ecee15eec6ac158d802eb8.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a95815b49cfc301dd2a3d06bb1b105b04bfbae7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/we/6f4fcddd234315db405c51971ff50a00640ddd0931ecee15eec6ac158d802eb8.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": 
"1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "XAIV2GWX5UZL7NNOCKNWC2I6AATKI6664P6FTQPRXS2M4AR4WJWA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/we/7292f5a4e0fd1eab0f63dfe7b8d7dc958ac540517dc0993e48d61a9bc1d20ea3.best_config b/SpecForge-ext/cache/compiled_kernels/we/7292f5a4e0fd1eab0f63dfe7b8d7dc958ac540517dc0993e48d61a9bc1d20ea3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c2d9b36c5180887fa413aa1eb230c04dc216dd00 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/we/7292f5a4e0fd1eab0f63dfe7b8d7dc958ac540517dc0993e48d61a9bc1d20ea3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/we/cwe54qdzud6xskhvjsubzqbviobtofi7rtla3cz3fvufmekfy4qf.py b/SpecForge-ext/cache/compiled_kernels/we/cwe54qdzud6xskhvjsubzqbviobtofi7rtla3cz3fvufmekfy4qf.py new file mode 100644 index 0000000000000000000000000000000000000000..57a06876dfe6e337da122cdfbe37a56ee51ac0c9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/we/cwe54qdzud6xskhvjsubzqbviobtofi7rtla3cz3fvufmekfy4qf.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 64, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 
'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, 
R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/we/cwempdtzfn3w5llag62nr4u5sjyacj7sgj6eqzt57fxgjrbhrohq.py b/SpecForge-ext/cache/compiled_kernels/we/cwempdtzfn3w5llag62nr4u5sjyacj7sgj6eqzt57fxgjrbhrohq.py new file mode 100644 index 0000000000000000000000000000000000000000..1d94811990382e08c61e15cfbb1f7e7cb6a23b40 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/we/cwempdtzfn3w5llag62nr4u5sjyacj7sgj6eqzt57fxgjrbhrohq.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): 
[['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, 
_tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/we/cwewx4zqbwcobeyw3qdqbgpgixwektloukjczxecy43rn53wthsl.py b/SpecForge-ext/cache/compiled_kernels/we/cwewx4zqbwcobeyw3qdqbgpgixwektloukjczxecy43rn53wthsl.py new file mode 100644 index 0000000000000000000000000000000000000000..707d887b246b5336597da3341787851fac9719f7 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/we/cwewx4zqbwcobeyw3qdqbgpgixwektloukjczxecy43rn53wthsl.py @@ -0,0 +1,24 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 2048}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 
'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wh/25ad0e5e5fedee4a89ef20e8e1554521c460359e0e0891dd290a117f21a569e7.best_config b/SpecForge-ext/cache/compiled_kernels/wh/25ad0e5e5fedee4a89ef20e8e1554521c460359e0e0891dd290a117f21a569e7.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5bb8c17ba21d4d24155a8e40c195fe8eab145e8c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wh/25ad0e5e5fedee4a89ef20e8e1554521c460359e0e0891dd290a117f21a569e7.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 16, "num_warps": 2, "num_stages": 1, "configs_hash": "9889a3900cf19f2f3cbdf50dfff07c1cd9bb504be42b4c95a8b2b6f156e5f333", "found_by_coordesc": false, "time_taken_ms": 34, "triton_cache_hash": "JOVMEF5UB3XPAO2GMKSDIGWCQSZP2YWR3UL465XFLBODYHCLBU5A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wh/cwhgu2bzfksecbm4int3okkh3ph54kdms6zfad7ngpfzq6h2owww.py b/SpecForge-ext/cache/compiled_kernels/wh/cwhgu2bzfksecbm4int3okkh3ph54kdms6zfad7ngpfzq6h2owww.py new file mode 100644 index 0000000000000000000000000000000000000000..9b75b298818245f5359af2171ef84d6e1137363c --- 
/dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wh/cwhgu2bzfksecbm4int3okkh3ph54kdms6zfad7ngpfzq6h2owww.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + 
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wj/cwjcr6usw2eqonbimd2nccontt4sk63dqadamrrm5fhq6kprylb3.py b/SpecForge-ext/cache/compiled_kernels/wj/cwjcr6usw2eqonbimd2nccontt4sk63dqadamrrm5fhq6kprylb3.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b6b89f5196e0d6931cdad8ecbf4822fac020c5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wj/cwjcr6usw2eqonbimd2nccontt4sk63dqadamrrm5fhq6kprylb3.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 
'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 
'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # 
inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/wj/cwjvrrote4o3omk3bsyldepk36ghybihiz4p3kpmh42qyiehh54y.py b/SpecForge-ext/cache/compiled_kernels/wj/cwjvrrote4o3omk3bsyldepk36ghybihiz4p3kpmh42qyiehh54y.py new file mode 100644 index 0000000000000000000000000000000000000000..939b768ac33f59b0e4125bd600740671dbd50f06 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wj/cwjvrrote4o3omk3bsyldepk36ghybihiz4p3kpmh42qyiehh54y.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile 
+from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 151936][311164928, 151936, 1]cuda:4" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:4" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:4" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:4"[num_users=1] = 
call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 
16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, 
_tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 16384, 151936, stream=stream4) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import 
print_performance + arg0_1 = rand_strided((8, 2048, 151936), (311164928, 151936, 1), device='cuda:4', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:4', dtype=torch.bool) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/wp/a4f8209267ace68281839acf75a9de7a751a4f87172bf04bfd62d4d60e085af1.best_config b/SpecForge-ext/cache/compiled_kernels/wp/a4f8209267ace68281839acf75a9de7a751a4f87172bf04bfd62d4d60e085af1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a7b716fb95af35565b080c5a7e63797d5b74e434 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wp/a4f8209267ace68281839acf75a9de7a751a4f87172bf04bfd62d4d60e085af1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 52, "triton_cache_hash": "BXWZSSWKBTIG7YDOE6QDLF3DYUHLUN57GPEDYW37ZDRQO2XWRGCQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wp/cwp3fb3c4kq3jjlj5mue2pouzku7f5r3znogrbfwtcofaovvgmqa.py b/SpecForge-ext/cache/compiled_kernels/wp/cwp3fb3c4kq3jjlj5mue2pouzku7f5r3znogrbfwtcofaovvgmqa.py new file mode 100644 index 0000000000000000000000000000000000000000..89c4f268b3b21f1a61baebfbfd85dc457d1f8477 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wp/cwp3fb3c4kq3jjlj5mue2pouzku7f5r3znogrbfwtcofaovvgmqa.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import 
libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, 
eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= 
tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wp/cwpp2ogxi4ziv4d6g6hohwssakk6mbdlaj4nklq5voabubccx3d6.py b/SpecForge-ext/cache/compiled_kernels/wp/cwpp2ogxi4ziv4d6g6hohwssakk6mbdlaj4nklq5voabubccx3d6.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a7eddfdabd42c8c0568cf45f336aef1f9b1d5e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wp/cwpp2ogxi4ziv4d6g6hohwssakk6mbdlaj4nklq5voabubccx3d6.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 
1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wp/e868346e0ceff5692ab30fc1f135f50b44539499cf0fab3da8a9db4d8265c961.best_config 
b/SpecForge-ext/cache/compiled_kernels/wp/e868346e0ceff5692ab30fc1f135f50b44539499cf0fab3da8a9db4d8265c961.best_config new file mode 100644 index 0000000000000000000000000000000000000000..65940914d216d062f2725b003671f86bc9112489 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wp/e868346e0ceff5692ab30fc1f135f50b44539499cf0fab3da8a9db4d8265c961.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 56, "triton_cache_hash": "XRR2QXTZQK4DSBTDJUTNXO6FEFXI2IIRKSC5GYSBWLTL56SKI4WA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/wr/cwrt3cdfiri2z4jso4afypedtru4cdebpo556yzgrqawlufswk26.py b/SpecForge-ext/cache/compiled_kernels/wr/cwrt3cdfiri2z4jso4afypedtru4cdebpo556yzgrqawlufswk26.py new file mode 100644 index 0000000000000000000000000000000000000000..82e6fdbfe09f689f48bc5ff81759c7c06c90c7f6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wr/cwrt3cdfiri2z4jso4afypedtru4cdebpo556yzgrqawlufswk26.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): 
[['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) 
+ tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/wu/cwumjzwhdyx25xmlkkoa6qowv4syi25q37ww7oouqscmqihsepmm.py b/SpecForge-ext/cache/compiled_kernels/wu/cwumjzwhdyx25xmlkkoa6qowv4syi25q37ww7oouqscmqihsepmm.py new file mode 100644 index 0000000000000000000000000000000000000000..92674fe9e67b555f174391e6e3c908165e647f9f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wu/cwumjzwhdyx25xmlkkoa6qowv4syi25q37ww7oouqscmqihsepmm.py @@ -0,0 +1,161 @@ +# AOT ID: ['11_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + 
+aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ao/caoqvgzvbk7exhnvkuijsznlx2ebywfk6vitynyaomz5hgx5szk5.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg1_1 : Tensor "bf16[2, s67, 32000][32000*s67, 32000, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %getitem : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:3" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, s67, 1][s67, 1, 2*s67]cuda:3" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args 
= (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, s67, 32000][32000*s67, 32000, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': 
True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask & xmask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask & xmask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, 
r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s67 = arg0_1 + assert_size_stride(arg1_1, (2, s67, 32000), (32000*s67, 32000, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf2 = empty_strided_cuda((2, s67, 32000), (32000*s67, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel = 2*s67 + stream3 = get_raw_stream(3) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg1_1, buf2, triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0_xnumel, 32000, stream=stream3) + del arg1_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2014 + arg1_1 = rand_strided((2, 2014, 32000), (64448000, 32000, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/wx/cwxwytckobiifa2ldje64mqhei52v7qeuhackau5a2g32rrtzau7.py b/SpecForge-ext/cache/compiled_kernels/wx/cwxwytckobiifa2ldje64mqhei52v7qeuhackau5a2g32rrtzau7.py new file mode 
100644 index 0000000000000000000000000000000000000000..568b0556de400f5568a3eb298d1bb5f2d8b6ea56 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/wx/cwxwytckobiifa2ldje64mqhei52v7qeuhackau5a2g32rrtzau7.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/j6/cj6lb2lwab43zvel34z3wsdzgjns7efwyvjd2ycexqj3bnayivh6.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: 
+# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 
'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/av/cavp7xan77tfr7qytfkp6sjrgkd6hvruiaqfzkeibtl5rtagscng.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => 
bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:3" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:3" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:3"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:3"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : 
Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# 
%lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 
127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = 
call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 
127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 
0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2c/c2cmsqbgkrofzfikzrnehvhp4wxhze4bly4ct5edlg3syiny626e.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:3" = 
PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:3" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 
'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/w6/cw62diausl6wewbsyobebfl3nbh45k6rt3qi3czewb6njmwszxmy.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => 
iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:3" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:3" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %unsqueeze_4 : Tensor 
"i64[2, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:3, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:3, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, 
((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = 
async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, 
ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ki/ckilhdvljb7gisvjeh27eoct2t7a3jnlamhdjdlsk4sziursotb7.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:3" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, 
((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xt/cxtmrcn6ndghukfh42cluziweoazfmfrz3jeopcrromwvj5m3lsj.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:3" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:3" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = 
call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 
16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ae/cae67gi6mcuey2plbam545opffg533rvjwoqmvcd54ylptdxjf2c.py +# 
Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:3" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:3"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, 
max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 
+ s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream3) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream3) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, 
remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream3) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream3) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, 
triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream3) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream3) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 
+ s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream3) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream3) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, 
aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream3) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream3) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream3 = get_raw_stream(3) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream3) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + 
s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream3) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 
1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream3 = get_raw_stream(3) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream3) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2014 + arg1_1 = 2014 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:3', dtype=torch.int64) + arg3_1 = 2014 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/xe/cxe74qazmcwxkyh3xlgupaetbeksmhlptcogpgxu7tfvr4arcob6.py b/SpecForge-ext/cache/compiled_kernels/xe/cxe74qazmcwxkyh3xlgupaetbeksmhlptcogpgxu7tfvr4arcob6.py new file mode 100644 index 0000000000000000000000000000000000000000..4d609118317ed3835e5e48a01f935653e5155ea4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xe/cxe74qazmcwxkyh3xlgupaetbeksmhlptcogpgxu7tfvr4arcob6.py @@ -0,0 +1,40 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + 
triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (ks0*((((r0_1 + 4*ks0*x0) // ks0) % 8)) + ((r0_1 % ks0))), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 
1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xe/cxeou6auzbu4dnrn2twxe573bmqovq7xnk4b6hydfbw53px4etc7.py b/SpecForge-ext/cache/compiled_kernels/xe/cxeou6auzbu4dnrn2twxe573bmqovq7xnk4b6hydfbw53px4etc7.py new file mode 100644 index 0000000000000000000000000000000000000000..864206d498e950424a14513008e9ea428469341e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xe/cxeou6auzbu4dnrn2twxe573bmqovq7xnk4b6hydfbw53px4etc7.py @@ -0,0 +1,46 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'out_ptr1': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 
'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_clamp_min_div_sum_3(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tmp4 = tl.load(in_ptr1 + (0)) + tmp5 = tl.broadcast_to(tmp4, [XBLOCK, 1]) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp2.to(tl.float32) + tmp8 = 1e-06 + tmp9 = triton_helpers.maximum(tmp7, tmp8) + tmp10 = (tmp6 / tmp9) + tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None) diff --git a/SpecForge-ext/cache/compiled_kernels/xm/cxmzwrsvqg6m4osmwishqxje7ewrr7v76bsoo62xlg7rfcwmygir.py b/SpecForge-ext/cache/compiled_kernels/xm/cxmzwrsvqg6m4osmwishqxje7ewrr7v76bsoo62xlg7rfcwmygir.py new file mode 100644 index 0000000000000000000000000000000000000000..ac63180ff4fcbd71a0580a60707841f5882f62bf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xm/cxmzwrsvqg6m4osmwishqxje7ewrr7v76bsoo62xlg7rfcwmygir.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import 
AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = 
tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xp/cxpp6d655eeypwfd7ox74cqiy7jxhai6ad6wrfxaftzthsrkzsyf.py b/SpecForge-ext/cache/compiled_kernels/xp/cxpp6d655eeypwfd7ox74cqiy7jxhai6ad6wrfxaftzthsrkzsyf.py new file mode 100644 index 0000000000000000000000000000000000000000..5d71b250f5b9b41debf9e3cd622c75ece26886f9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xp/cxpp6d655eeypwfd7ox74cqiy7jxhai6ad6wrfxaftzthsrkzsyf.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from 
torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = 
call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 
'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/za/czavugmlxltvuliq4dhv6chq23pigmyj6dktmupor2jiyiklfef6.py 
+# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[8][1]cuda:2" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: 
torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): 
[['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : 
tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. 
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * 
SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. 
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. 
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False 
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : 
tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = 
tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. 
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + 
HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) 
+ tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + 
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + 
assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream2) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = 
runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import 
compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/xp/cxprnl6wyrkxecwymb5nwdyyiuq4vpew4zlpdy2zpq7whdmm3twe.py b/SpecForge-ext/cache/compiled_kernels/xp/cxprnl6wyrkxecwymb5nwdyyiuq4vpew4zlpdy2zpq7whdmm3twe.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcc29af3d42d4af83da23cf6b97fbfecac8f338 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xp/cxprnl6wyrkxecwymb5nwdyyiuq4vpew4zlpdy2zpq7whdmm3twe.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 
* tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/xu/cxud6f772kqmkn5hbju35maiws4eoueh62pnfgpfpdg6c6uykwzz.py b/SpecForge-ext/cache/compiled_kernels/xu/cxud6f772kqmkn5hbju35maiws4eoueh62pnfgpfpdg6c6uykwzz.py new file mode 100644 index 0000000000000000000000000000000000000000..705dd0953dbdeee04f05625c1c30cbe2a9d70b1a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/xu/cxud6f772kqmkn5hbju35maiws4eoueh62pnfgpfpdg6c6uykwzz.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': 
True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = 
arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/yd/864a4c6b0248dca933ac3672934a07906f020bb888bf1d34a30f790fcf8d995b.best_config b/SpecForge-ext/cache/compiled_kernels/yd/864a4c6b0248dca933ac3672934a07906f020bb888bf1d34a30f790fcf8d995b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..422e1afda877306872879bb2d038c5a4e486fa13 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/yd/864a4c6b0248dca933ac3672934a07906f020bb888bf1d34a30f790fcf8d995b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 27, "triton_cache_hash": "NFXDDTZIEQAGR5JBWWCXV73QZILXJMIVVJVPZH3CHAIULDAND5UQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/yd/cydbdwjbypjuoymwltsnbdacvkducp6g2mxtlx6bqk4tkmf4qofp.py b/SpecForge-ext/cache/compiled_kernels/yd/cydbdwjbypjuoymwltsnbdacvkducp6g2mxtlx6bqk4tkmf4qofp.py new file mode 100644 index 0000000000000000000000000000000000000000..55336bac323c4869fb673fb568b600c755d4cb5b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yd/cydbdwjbypjuoymwltsnbdacvkducp6g2mxtlx6bqk4tkmf4qofp.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 
16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : 
tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. 
+ off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    # Returns how far the element offsets must advance after iteration
    # `loop_iter`.
    #   * BLOCKS_ARE_CONTIGUOUS: always one BLOCK.
    #   * otherwise: one BLOCK, except on the last BLOCK-sized step of a sparse
    #     block ((loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0), where the step
    #     is the distance to the first BLOCK of the next entry in `col_indices`.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    sparse_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    this_block = tl.load(col_indices + sparse_idx, eviction_policy="evict_last")
    # The look-ahead load is masked so it never reads past `total_blocks`
    # entries; no `other` value is supplied for the masked-off case.
    following_block = tl.load(
        col_indices + sparse_idx + 1,
        eviction_policy="evict_last",
        mask=sparse_idx + 1 < total_blocks,
    )
    at_sparse_block_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    # Distance between the two sparse blocks, minus the (MULTIPLE - 1) BLOCK
    # steps already taken inside the current sparse block.
    jump = (following_block - this_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    # Arithmetic select between `jump` and a plain BLOCK step.
    return jump * at_sparse_block_end + (1 - at_sparse_block_end) * BLOCK

@triton.jit
def get_bounded_indices(indices, max_len=None):
    # Wraps `indices` modulo `max_len` when a bound is given; otherwise the
    # indices are passed through untouched.
    bounded = indices % max_len if max_len is not None else indices
    return bounded

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    # Loads through a block pointer, boundary-checking (with zero padding) only
    # the dimensions that need it: dim 0 unless IS_DIVISIBLE, dim 1 unless
    # SAFE_HEAD_DIM.
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Loads a 2D tile addressed by row offsets `offs_m` and column offsets
    # `offs_n`. A dimension flagged as divisible is loaded unmasked; any other
    # dimension is masked against its length (M_LEN / N_LEN) and filled with 0.0.

    # Build the tile pointers only when both strides were supplied; otherwise
    # `ptr` is used as given.
    if stride_m is not None and stride_n is not None:
        row_term = offs_m[:, None] * stride_m
        col_term = offs_n[None, :] * stride_n
        ptr = ptr + row_term + col_term

    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Both divisible: no masking needed.
            return tl.load(ptr)
        else:
            return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
    else:
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
        else:
            return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
@triton.jit
def forward_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Runs forward_block_mn once per KV step in [block_n_start, block_n_end),
    # threading the running (acc, l_i, m_i) through every call and returning
    # the final triple. After each step the KV offsets advance by whatever
    # get_offset_for_next_block reports.

    # Kernel parameters re-declared locally so they need not be plumbed
    # through the call chain; only the ones this function reads are kept.
    PRESCALE_QK : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.08838834764831843
    BLOCK_N : tl.constexpr = 64
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        if not IS_DIVISIBLE:
            # Non-divisible sequence length: every block is processed with
            # boundary checks enabled. The generator's note says this is on
            # par with (or slightly faster than) checking only the last block
            # in the forward pass, while the backward pass checks only the
            # last block because that is a lot faster there.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
        else:
            # Here IS_DIVISIBLE acts as the
            # start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
            acc, l_i, m_i = forward_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )

        # Step to the next KV block (possibly jumping to another sparse block).
        step = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        offs_n = offs_n + step
        kv_offset = kv_offset + step

    return acc, l_i, m_i
@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 16},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
    # For each of 32 rows of 16 int64 values read from in_ptr0:
    #   * flags entries strictly between 0 and 16384 ("partial") and entries
    #     equal to 16384 ("full");
    #   * stable-sorts each flag row in descending order, carrying the column
    #     index along, so flagged columns come first in original order;
    #   * writes the per-row flag counts (out_ptr4 / out_ptr5), the sorted
    #     column orders (out_ptr6 / out_ptr8), and scatters 1s into the
    #     17-wide rows of out_ptr7 / out_ptr9.
    # NOTE(review): 16384 presumably equals 128 * 128, i.e. a completely
    # unmasked block of a block-sparse mask - confirm against the caller.
    xnumel = 32
    R0_BLOCK: tl.constexpr = 16
    row_ids = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = row_ids < xnumel
    col_ids = tl.arange(0, R0_BLOCK)[None, :]

    counts = tl.load(in_ptr0 + (col_ids + 16*row_ids), row_ok, other=0.0)
    zero_i64 = tl.full([1, 1], 0, tl.int64)
    full_value = tl.full([1, 1], 16384, tl.int64)

    # 0/1 flags as int32 (via int8, exactly as the generated code casts them).
    partial_flag = ((counts > zero_i64) & (counts < full_value)).to(tl.int8).to(tl.int32)
    full_flag = (counts == full_value).to(tl.int8).to(tl.int32)

    # Column indices travel with the flags through the sort as int16.
    col_payload = tl.broadcast_to(col_ids.to(tl.int16), [XBLOCK, R0_BLOCK])
    partial_sorted, partial_order, = triton_helpers.sort_with_index(
        tl.broadcast_to(partial_flag, [XBLOCK, R0_BLOCK]), col_payload, None, 1, stable=True, descending=True)
    full_sorted, full_order, = triton_helpers.sort_with_index(
        tl.broadcast_to(full_flag, [XBLOCK, R0_BLOCK]), col_payload, None, 1, stable=True, descending=True)

    # Per-row number of flagged entries; rows beyond xnumel contribute 0.
    partial_wide = tl.broadcast_to(partial_flag.to(tl.int64), [XBLOCK, R0_BLOCK])
    num_partial = tl.sum(tl.where(row_ok, partial_wide, 0), 1)[:, None].to(tl.int64).to(tl.int32)
    full_wide = tl.broadcast_to(full_flag.to(tl.int64), [XBLOCK, R0_BLOCK])
    num_full = tl.sum(tl.where(row_ok, full_wide, 0), 1)[:, None].to(tl.int64).to(tl.int32)

    overflow_col = tl.full([1, 1], 16, tl.int32)
    row_width = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
    one_i32 = tl.full([1, 1], 1, tl.int32)

    # Scatter targets: the first `count` sorted slots point at their own
    # column, every remaining slot points at the extra column 16. Negative
    # indices are wrapped by the row width (17) before the bounds assert.
    partial_cols = partial_order.to(tl.int64).to(tl.int32)
    partial_target = tl.where(col_ids < num_partial, partial_cols, overflow_col)
    partial_target = tl.where(partial_target < 0, partial_target + row_width, partial_target)
    tl.device_assert(((0 <= partial_target) & (partial_target < 17)) | ~(row_ok), "index out of bounds: 0 <= tmp40 < 17")

    full_cols = full_order.to(tl.int64).to(tl.int32)
    full_target = tl.where(col_ids < num_full, full_cols, overflow_col)
    full_target = tl.where(full_target < 0, full_target + row_width, full_target)
    tl.device_assert(((0 <= full_target) & (full_target < 17)) | ~(row_ok), "index out of bounds: 0 <= tmp49 < 17")

    tl.store(out_ptr4 + (row_ids), num_partial, row_ok)
    tl.store(out_ptr5 + (row_ids), num_full, row_ok)
    tl.store(out_ptr6 + (col_ids + 16*row_ids), partial_cols, row_ok)
    tl.store(out_ptr7 + (tl.broadcast_to(partial_target + 17*row_ids, [XBLOCK, R0_BLOCK])), one_i32, row_ok)
    tl.store(out_ptr8 + (col_ids + 16*row_ids), full_cols, row_ok)
    tl.store(out_ptr9 + (tl.broadcast_to(full_target + 17*row_ids, [XBLOCK, R0_BLOCK])), one_i32, row_ok)
'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = 
@triton_heuristics.reduction(
    size_hints={'x': 131072, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}}
)
@triton.jit
def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Row-wise dot product over a reduction axis of length 128:
    #     out_ptr1[x] = sum_r float32(in_ptr0[...]) * float32(in_ptr1[...]) - 0.0
    # in_ptr1 is read as contiguous rows of 128; in_ptr0 is addressed through
    # the index x split into three parts with strides 4096 / 128 / 8388608.
    # No x-mask is applied to loads or the store, as in the generated code.
    r0_numel = 128
    x_flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    r_lanes = tl.arange(0, R0_BLOCK)[None, :]

    # Split of the flat x index used to address in_ptr0.
    x_low = x_flat % 2048
    x_mid = (x_flat // 2048) % 32
    x_high = x_flat // 65536

    partial_sums = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_start in range(0, r0_numel, R0_BLOCK):
        r_ids = r0_start + r_lanes
        r_ok = r_ids < r0_numel
        lhs = tl.load(in_ptr0 + (r_ids + 128*x_mid + 4096*x_low + 8388608*x_high), r_ok, eviction_policy='evict_first', other=0.0).to(tl.float32)
        rhs = tl.load(in_ptr1 + (r_ids + 128*x_flat), r_ok, eviction_policy='evict_first', other=0.0).to(tl.float32)
        products = tl.broadcast_to(lhs * rhs, [XBLOCK, R0_BLOCK])
        # Only lanes inside the reduction range contribute.
        partial_sums = tl.where(r_ok, partial_sums + products, partial_sums)

    row_total = tl.sum(partial_sums, 1)[:, None]
    # The kernel name suggests the subtracted 0.0 comes from a fused `zeros`
    # op; it is kept so the emitted arithmetic matches the generated kernel.
    result = row_total.to(tl.float32) - 0.0
    tl.store(out_ptr1 + (x_flat), result, None)
@triton_heuristics.reduction(
    size_hints={'x': 4096, 'r0_': 16384},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # For every x, counts how many of the 128 x 128 (row, col) positions of one
    # tile satisfy a boolean predicate, then classifies the tile:
    #   out_ptr1[x] = 1 if 0 < count < 16384 else 0
    #   out_ptr2[x] = 1 if count == 16384 else 0
    # With `limit` loaded from in_ptr0, the predicate is
    #   (row >= col and col < limit and row < limit)
    #   or (col >= 2048 and col % 2048 < limit and (col - row) mod 2048 == 0)
    # and is forced to False wherever col >= ks1.
    # NOTE(review): rows/cols look like query/key positions and `limit` like a
    # per-batch valid length - confirm against the mask function that produced
    # this kernel.
    r0_numel = 16384
    x_flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    x_ok = x_flat < xnumel
    r_lanes = tl.arange(0, R0_BLOCK)[None, :]

    col_tile = x_flat % ks0
    row_tile = (x_flat // ks0) % 16
    limit_idx = x_flat // ks2

    hit_count = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_start in range(0, r0_numel, R0_BLOCK):
        r_ids = r0_start + r_lanes
        r_ok = r_ids < r0_numel
        # The reduction index enumerates the tile: low 7 bits pick the column,
        # the remaining bits pick the row.
        col_in_tile = r_ids % 128
        row_in_tile = r_ids // 128
        col_pos = col_in_tile + 128*col_tile
        row_pos = row_in_tile + 128*row_tile
        col_in_range = col_pos < ks1

        limit = tl.load(in_ptr0 + (tl.broadcast_to(limit_idx, [XBLOCK, R0_BLOCK])), r_ok & col_in_range & x_ok, eviction_policy='evict_last', other=0.0)

        # First alternative: row >= col with both positions below the limit.
        both_below = (col_pos < limit) & (row_pos < limit)
        first_alt = tl.full([1, 1], False, tl.int1) | ((row_pos >= col_pos) & both_below)

        # Second alternative: columns at or past 2048 whose wrapped position is
        # below the limit and whose distance to the row is a multiple of 2048.
        period = tl.full([1, 1], 2048, tl.int64)
        past_period = (col_pos >= period) & ((col_pos % 2048) < limit)
        distance = col_in_tile + ((-1)*row_in_tile) + ((-128)*row_tile) + 128*col_tile
        rem = distance % period
        # Sign-corrected remainder: add the divisor back when the remainder is
        # non-zero and its sign differs from the divisor's.
        zero_i32 = tl.full([1, 1], 0, tl.int32)
        rem_negative = (libdevice.signbit(rem) != 0) if (rem).dtype is tl.float32 else rem < 0
        period_negative = (libdevice.signbit(period) != 0) if (period).dtype is tl.float32 else period < 0
        needs_fixup = (rem != zero_i32) & (rem_negative != period_negative)
        wrapped = tl.where(needs_fixup, rem + period, rem)
        on_stride = wrapped == tl.full([1, 1], 0, tl.int64)
        second_alt = past_period & on_stride

        allowed = first_alt | second_alt
        # Columns at or beyond ks1 never count.
        allowed = tl.where(col_in_range, allowed, tl.full(allowed.shape, False, allowed.dtype))
        contribution = tl.broadcast_to(allowed.to(tl.int64), [XBLOCK, R0_BLOCK])
        hit_count = tl.where(r_ok & x_ok, hit_count + contribution, hit_count)

    total = tl.sum(hit_count, 1)[:, None]
    zero_i64 = tl.full([1, 1], 0, tl.int64)
    tile_size = tl.full([1, 1], 16384, tl.int64)
    is_partial = ((total > zero_i64) & (total < tile_size)).to(tl.int8).to(tl.int32)
    is_full = (total == tile_size).to(tl.int8).to(tl.int32)
    tl.store(out_ptr1 + (x_flat), is_partial, x_ok)
    tl.store(out_ptr2 + (x_flat), is_full, x_ok)
@triton_heuristics.reduction(
    size_hints={'x': 2, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 393216}}
)
@triton.jit
def triton_red_fused_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # For each of the 2 rows (8192 int64 elements per row):
    #     out_ptr0[x] = sum_r (in_ptr0[x, r] == in_ptr1[x, r]) * in_ptr2[x, r]
    # i.e. the sum of in_ptr2 over the positions where the first two inputs
    # are equal.
    xnumel = 2
    r0_numel = 8192
    row_ids = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = row_ids < xnumel
    r_lanes = tl.arange(0, R0_BLOCK)[None, :]

    running = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
    for r0_start in range(0, r0_numel, R0_BLOCK):
        r_ids = r0_start + r_lanes
        live = (r_ids < r0_numel) & row_ok
        flat = r_ids + 8192*row_ids
        lhs = tl.load(in_ptr0 + (flat), live, eviction_policy='evict_first', other=0.0)
        rhs = tl.load(in_ptr1 + (flat), live, eviction_policy='evict_first', other=0.0)
        weight = tl.load(in_ptr2 + (flat), live, eviction_policy='evict_first', other=0.0)
        # Equality as 0/1 in int64, scaled by the third input.
        weighted = (lhs == rhs).to(tl.int64) * weight
        weighted = tl.broadcast_to(weighted, [XBLOCK, R0_BLOCK])
        # Only in-range lanes of valid rows are accumulated.
        running = tl.where(live, running + weighted, running)

    row_total = tl.sum(running, 1)[:, None]
    tl.store(out_ptr0 + (row_ids), row_total, row_ok)
a/SpecForge-ext/cache/compiled_kernels/zg/czg53pk3l24wn74a6bylpzbgb44kx2zfplies7n5uiiogfzwg4z2.py b/SpecForge-ext/cache/compiled_kernels/zg/czg53pk3l24wn74a6bylpzbgb44kx2zfplies7n5uiiogfzwg4z2.py new file mode 100644 index 0000000000000000000000000000000000000000..37ba16f3f92b1f38d8600705809712a546f91e85 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zg/czg53pk3l24wn74a6bylpzbgb44kx2zfplies7n5uiiogfzwg4z2.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 
'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mul_sum_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = xindex // ks0 + x0 = (xindex % ks0) + _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = r0_2 + x1*((31 + ks1*ks2) // 32) + tmp1 = ks1*ks2 + tmp2 = tmp0 < tmp1 + tmp3 = tl.load(in_ptr0 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x0 + ks0*(((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2)))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp5 = tmp4.to(tl.float32) + tmp6 = tl.load(in_ptr2 + (((r0_2 + x1*((31 + ks1*ks2) // 32)) % (ks1*ks2))), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp5 * tmp6 + tmp8 = tmp7.to(tl.float32) + tmp9 = tmp3 * tmp8 + tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype) + tmp11 = tl.where(tmp2, tmp9, tmp10) + tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) + tmp14 = _tmp13 + tmp12 + _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13) + tmp13 = tl.sum(_tmp13, 1)[:, None] + tl.store(out_ptr0 + (x3), tmp13, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zg/czgvbv7mroso4vhea3ibxxosa5syusucgs7u4gi2cl3bj2oi4u5l.py 
b/SpecForge-ext/cache/compiled_kernels/zg/czgvbv7mroso4vhea3ibxxosa5syusucgs7u4gi2cl3bj2oi4u5l.py new file mode 100644 index 0000000000000000000000000000000000000000..3f795b6b5653f36ca683d3e54a5999250642dcdb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zg/czgvbv7mroso4vhea3ibxxosa5syusucgs7u4gi2cl3bj2oi4u5l.py @@ -0,0 +1,25 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 544 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + 
tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zj/34ce7267e05f32abd011e61d991fc62e33975090ed3219470f5fd09e944548f2.best_config b/SpecForge-ext/cache/compiled_kernels/zj/34ce7267e05f32abd011e61d991fc62e33975090ed3219470f5fd09e944548f2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b9c83cd70cc4f7d46eca037549afe001d843ad6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zj/34ce7267e05f32abd011e61d991fc62e33975090ed3219470f5fd09e944548f2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zj/czj6vv4i5yhbwkg6and5xgv6trjy7uc6ixza33q5syppm7ku5e6s.py b/SpecForge-ext/cache/compiled_kernels/zj/czj6vv4i5yhbwkg6and5xgv6trjy7uc6ixza33q5syppm7ku5e6s.py new file mode 100644 index 0000000000000000000000000000000000000000..67ee674ca0cd1e526eaaa713aaeabf322cc59f44 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zj/czj6vv4i5yhbwkg6and5xgv6trjy7uc6ixza33q5syppm7ku5e6s.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from 
torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kb/ckbv35732vlbuwbdmg6pzqsokaq3orhgmp5ym4q4fg2onojs4phw.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 151936][311164928, 151936, 1]cuda:1" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:1" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:1" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:1" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[8, 2048][2048, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[8, 2048][2048, 
1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[8, 2048, 1][2048, 1, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, 
eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (8, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((8, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream1 = get_raw_stream(1) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 16384, 151936, stream=stream1) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 151936), (311164928, 151936, 1), 
device='cuda:1', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:1', dtype=torch.bool) + arg2_1 = rand_strided((8, 2048, 1), (2048, 1, 1), device='cuda:1', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/zj/czjfy4ep2ki2h5g622iazagvjfsjocncon5vbpldegigrx5lxkhe.py b/SpecForge-ext/cache/compiled_kernels/zj/czjfy4ep2ki2h5g622iazagvjfsjocncon5vbpldegigrx5lxkhe.py new file mode 100644 index 0000000000000000000000000000000000000000..50d9453be040ddb544971f5e15001b06bd689f9d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zj/czjfy4ep2ki2h5g622iazagvjfsjocncon5vbpldegigrx5lxkhe.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 
'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = 
tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zj/czjh2orviux4gjtyibvg2imw6ffvw36tn4jyfkfu6fpomfwp53n4.py b/SpecForge-ext/cache/compiled_kernels/zj/czjh2orviux4gjtyibvg2imw6ffvw36tn4jyfkfu6fpomfwp53n4.py new file mode 100644 index 0000000000000000000000000000000000000000..f4eef58ceaec48f6cc6b16d756c11bde7b5c2d04 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zj/czjh2orviux4gjtyibvg2imw6ffvw36tn4jyfkfu6fpomfwp53n4.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = 
torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/il/cilp5kqvrljsbeu2eadyyvf76cdxypu34m6m4bfrk3qitwvhuaei.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:7" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[8][1]cuda:7" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = 
call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 
'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = 
arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + 
tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if 
(tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, 
primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 8, 32, stream=stream7) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call 
+recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:7', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:7', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:7', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:7', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:7', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', 
benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/zj/czjkm7k57yqvjckfd6v5afshsyjeafn25srmkljdz6wbhdiqqu7l.py b/SpecForge-ext/cache/compiled_kernels/zj/czjkm7k57yqvjckfd6v5afshsyjeafn25srmkljdz6wbhdiqqu7l.py new file mode 100644 index 0000000000000000000000000000000000000000..245e33bcc9adab961c45086b6ab83220e97f9bbd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zj/czjkm7k57yqvjckfd6v5afshsyjeafn25srmkljdz6wbhdiqqu7l.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 
'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 
+ tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zl/382525a86cdd789bf42adb1705d045647126c2067725fa125ca2ec0acf0cfbe2.best_config b/SpecForge-ext/cache/compiled_kernels/zl/382525a86cdd789bf42adb1705d045647126c2067725fa125ca2ec0acf0cfbe2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6c3a3559496e6e4d68292da2e678eca0b03342ab --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zl/382525a86cdd789bf42adb1705d045647126c2067725fa125ca2ec0acf0cfbe2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "NFABHOURJ57C2IKXWDMS2VHZ76PCVKJVD7V6CBWJDLMT5TQE5GFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zl/czl6w4fdttcxezteyksfn5e72p4oeudpkffzkhtwid2jbs7q45lr.py b/SpecForge-ext/cache/compiled_kernels/zl/czl6w4fdttcxezteyksfn5e72p4oeudpkffzkhtwid2jbs7q45lr.py new file mode 100644 index 0000000000000000000000000000000000000000..06cafab85e34f5801bb969a3d57365f283b2c2ed --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zl/czl6w4fdttcxezteyksfn5e72p4oeudpkffzkhtwid2jbs7q45lr.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import 
_cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hb/chbrjdlxerx7gk6nuxxzvrjjkodqe76vwjpbpxzupgjmc6dfh6un.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 32000][65536000, 32000, 1]cuda:1" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:1" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[2, 2048, 1][2048, 1, 4096]cuda:1" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:1"[num_users=1] = 
call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[2, 2048, 32000][65536000, 32000, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 
'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1310720000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', 
other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf2 = empty_strided_cuda((2, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream1 = get_raw_stream(1) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 4096, 32000, stream=stream1) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 32000), (65536000, 32000, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/zl/czlkgr2ysljuz6n7awr3zuc4mf2crnythezffygnjef6dv2xy4uv.py b/SpecForge-ext/cache/compiled_kernels/zl/czlkgr2ysljuz6n7awr3zuc4mf2crnythezffygnjef6dv2xy4uv.py new file mode 100644 index 
0000000000000000000000000000000000000000..61175a6777d4e4456580e846751b1097c58b7241 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zl/czlkgr2ysljuz6n7awr3zuc4mf2crnythezffygnjef6dv2xy4uv.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def 
triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zl/czlx3bjrur52gcpjbjyrdrlp4dizvnit36zdutjf7kvswid53e2m.py 
b/SpecForge-ext/cache/compiled_kernels/zl/czlx3bjrur52gcpjbjyrdrlp4dizvnit36zdutjf7kvswid53e2m.py new file mode 100644 index 0000000000000000000000000000000000000000..cfad59e86c021cdad7b8276ff4ba036be0e5cf2f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zl/czlx3bjrur52gcpjbjyrdrlp4dizvnit36zdutjf7kvswid53e2m.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 
'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + 
# + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? 
If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. 
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + 
tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + 
return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, 
BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype 
is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. 
+ off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
+ if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zr/3e86bca9c722413d3894999293a170744b35b0aa6aa08f1ece4df6248146bb32.best_config b/SpecForge-ext/cache/compiled_kernels/zr/3e86bca9c722413d3894999293a170744b35b0aa6aa08f1ece4df6248146bb32.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e4cb24d6d43a8cad608f96e9c9a993411ec60631 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/zr/3e86bca9c722413d3894999293a170744b35b0aa6aa08f1ece4df6248146bb32.best_config @@ -0,0 +1 @@ +{"XBLOCK": 2, "R0_BLOCK": 16, "num_warps": 2, "num_stages": 1, "configs_hash": "21ad1ee516cd6d15e1fb8e88c10082cd54bef654f8a281c7d5ccd54b6509a685", "found_by_coordesc": false, "time_taken_ms": 29, "triton_cache_hash": "5K5257V2CIYPSOSDD5J3O2K2XOBM6YPUNDVRXIBHB3B26LQ54CSA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/zr/czrritdkatkiq6gisucy3tyfojxxy2il47j3sgudvr7n3phywdxk.py b/SpecForge-ext/cache/compiled_kernels/zr/czrritdkatkiq6gisucy3tyfojxxy2il47j3sgudvr7n3phywdxk.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ea5e30c00bf5fe369e59d7a9af7751d2a03f14 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/zr/czrritdkatkiq6gisucy3tyfojxxy2il47j3sgudvr7n3phywdxk.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': 
False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/zw/czwh6tgkq6scdstgzueb3goqqnllndikoasj2i2iehu2qyvoccwt.py b/SpecForge-ext/cache/compiled_kernels/zw/czwh6tgkq6scdstgzueb3goqqnllndikoasj2i2iehu2qyvoccwt.py new file mode 100644 index 0000000000000000000000000000000000000000..5d3beaccd32845f8901efdc372ccf307b31faa32 --- /dev/null +++ 
b/SpecForge-ext/cache/compiled_kernels/zw/czwh6tgkq6scdstgzueb3goqqnllndikoasj2i2iehu2qyvoccwt.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, 
xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None)